agentrun-inner-test 0.0.46__py3-none-any.whl → 0.0.64__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun/__init__.py +34 -1
- agentrun/agent_runtime/__endpoint_async_template.py +40 -0
- agentrun/agent_runtime/api/control.py +1 -1
- agentrun/agent_runtime/endpoint.py +79 -0
- agentrun/credential/api/control.py +1 -1
- agentrun/integration/agentscope/__init__.py +2 -1
- agentrun/integration/agentscope/builtin.py +23 -0
- agentrun/integration/builtin/__init__.py +2 -0
- agentrun/integration/builtin/knowledgebase.py +137 -0
- agentrun/integration/crewai/__init__.py +2 -1
- agentrun/integration/crewai/builtin.py +23 -0
- agentrun/integration/google_adk/__init__.py +2 -1
- agentrun/integration/google_adk/builtin.py +23 -0
- agentrun/integration/langchain/__init__.py +2 -1
- agentrun/integration/langchain/builtin.py +23 -0
- agentrun/integration/langgraph/__init__.py +2 -1
- agentrun/integration/langgraph/builtin.py +23 -0
- agentrun/integration/pydantic_ai/__init__.py +2 -1
- agentrun/integration/pydantic_ai/builtin.py +23 -0
- agentrun/knowledgebase/__client_async_template.py +173 -0
- agentrun/knowledgebase/__init__.py +53 -0
- agentrun/knowledgebase/__knowledgebase_async_template.py +438 -0
- agentrun/knowledgebase/api/__data_async_template.py +414 -0
- agentrun/knowledgebase/api/__init__.py +19 -0
- agentrun/knowledgebase/api/control.py +606 -0
- agentrun/knowledgebase/api/data.py +624 -0
- agentrun/knowledgebase/client.py +311 -0
- agentrun/knowledgebase/knowledgebase.py +748 -0
- agentrun/knowledgebase/model.py +270 -0
- agentrun/memory_collection/__client_async_template.py +178 -0
- agentrun/memory_collection/__init__.py +37 -0
- agentrun/memory_collection/__memory_collection_async_template.py +457 -0
- agentrun/memory_collection/api/__init__.py +5 -0
- agentrun/memory_collection/api/control.py +610 -0
- agentrun/memory_collection/client.py +323 -0
- agentrun/memory_collection/memory_collection.py +844 -0
- agentrun/memory_collection/model.py +162 -0
- agentrun/model/api/control.py +1 -1
- agentrun/sandbox/aio_sandbox.py +11 -4
- agentrun/sandbox/api/control.py +1 -1
- agentrun/sandbox/browser_sandbox.py +2 -2
- agentrun/sandbox/model.py +0 -13
- agentrun/toolset/api/control.py +1 -1
- agentrun/toolset/toolset.py +1 -0
- agentrun/utils/__data_api_async_template.py +1 -0
- agentrun/utils/config.py +12 -0
- agentrun/utils/control_api.py +27 -0
- agentrun/utils/data_api.py +1 -0
- {agentrun_inner_test-0.0.46.dist-info → agentrun_inner_test-0.0.64.dist-info}/METADATA +4 -2
- {agentrun_inner_test-0.0.46.dist-info → agentrun_inner_test-0.0.64.dist-info}/RECORD +53 -34
- {agentrun_inner_test-0.0.46.dist-info → agentrun_inner_test-0.0.64.dist-info}/WHEEL +0 -0
- {agentrun_inner_test-0.0.46.dist-info → agentrun_inner_test-0.0.64.dist-info}/licenses/LICENSE +0 -0
- {agentrun_inner_test-0.0.46.dist-info → agentrun_inner_test-0.0.64.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,624 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This file is auto generated by the code generation script.
|
|
3
|
+
Do not modify this file manually.
|
|
4
|
+
Use the `make codegen` command to regenerate.
|
|
5
|
+
|
|
6
|
+
当前文件为自动生成的知识库数据链路 API 代码。请勿手动修改此文件。
|
|
7
|
+
使用 `make codegen` 命令重新生成。
|
|
8
|
+
|
|
9
|
+
source: agentrun/knowledgebase/api/__data_async_template.py
|
|
10
|
+
|
|
11
|
+
KnowledgeBase 数据链路 API / KnowledgeBase Data API
|
|
12
|
+
|
|
13
|
+
提供知识库检索功能的数据链路 API。
|
|
14
|
+
Provides data API for knowledge base retrieval operations.
|
|
15
|
+
|
|
16
|
+
根据不同的 provider 类型(ragflow / bailian)分发到不同的实现。
|
|
17
|
+
Dispatches to different implementations based on provider type (ragflow / bailian).
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from abc import ABC, abstractmethod
|
|
21
|
+
from typing import Any, Dict, List, Optional, Union
|
|
22
|
+
|
|
23
|
+
from alibabacloud_bailian20231229 import models as bailian_models
|
|
24
|
+
import httpx
|
|
25
|
+
|
|
26
|
+
from agentrun.utils.config import Config
|
|
27
|
+
from agentrun.utils.control_api import ControlAPI
|
|
28
|
+
from agentrun.utils.data_api import DataAPI, ResourceType
|
|
29
|
+
from agentrun.utils.log import logger
|
|
30
|
+
|
|
31
|
+
from ..model import (
|
|
32
|
+
BailianProviderSettings,
|
|
33
|
+
BailianRetrieveSettings,
|
|
34
|
+
KnowledgeBaseProvider,
|
|
35
|
+
RagFlowProviderSettings,
|
|
36
|
+
RagFlowRetrieveSettings,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class KnowledgeBaseDataAPI(ABC):
    """Knowledge base data API base class.

    Defines the abstract retrieval interface. Each knowledge base provider
    (e.g. RagFlow, Bailian) supplies a concrete subclass.
    """

    def __init__(
        self,
        knowledge_base_name: str,
        config: Optional[Config] = None,
    ):
        """Initialize the knowledge base data API.

        Args:
            knowledge_base_name: Knowledge base name.
            config: Optional configuration; normalized through
                ``Config.with_configs`` and stored on ``self.config``.
        """
        self.knowledge_base_name = knowledge_base_name
        self.config = Config.with_configs(config)

    @abstractmethod
    async def retrieve_async(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """Retrieve from the knowledge base (async).

        Args:
            query: Query text.
            config: Optional per-call configuration.

        Returns:
            Dict[str, Any]: Retrieval results.

        Raises:
            NotImplementedError: If a subclass does not override this method.
        """
        raise NotImplementedError("Subclasses must implement retrieve_async")

    @abstractmethod
    def retrieve(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """Retrieve from the knowledge base (sync).

        Args:
            query: Query text.
            config: Optional per-call configuration.

        Returns:
            Dict[str, Any]: Retrieval results.

        Raises:
            NotImplementedError: If a subclass does not override this method.
        """
        raise NotImplementedError("Subclasses must implement retrieve")
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class RagFlowDataAPI(KnowledgeBaseDataAPI):
    """RagFlow knowledge base data API.

    Implements retrieval against a RagFlow server by POSTing to
    ``{base_url}/api/v1/retrieval``. The request is authenticated with a
    bearer token taken from the secret of a named credential.
    """

    def __init__(
        self,
        knowledge_base_name: str,
        config: Optional[Config] = None,
        provider_settings: Optional[RagFlowProviderSettings] = None,
        retrieve_settings: Optional[RagFlowRetrieveSettings] = None,
        credential_name: Optional[str] = None,
    ):
        """Initialize the RagFlow knowledge base data API.

        Args:
            knowledge_base_name: Knowledge base name.
            config: Optional configuration, normalized by the base class.
            provider_settings: RagFlow provider settings; ``base_url`` and
                ``dataset_ids`` are read at retrieval time.
            retrieve_settings: Optional RagFlow retrieval tuning settings.
            credential_name: Name of the credential whose secret is used as
                the RagFlow API key.
        """
        super().__init__(knowledge_base_name, config)
        self.provider_settings = provider_settings
        self.retrieve_settings = retrieve_settings
        self.credential_name = credential_name

    def _require_provider_settings(self) -> RagFlowProviderSettings:
        """Return ``provider_settings``, raising if it is not set.

        Raises:
            ValueError: If ``provider_settings`` is None.
        """
        if self.provider_settings is None:
            raise ValueError(
                "provider_settings is required for RagFlow retrieval"
            )
        return self.provider_settings

    def _require_credential_name(self) -> str:
        """Return ``credential_name``, raising if it is empty or unset.

        Raises:
            ValueError: If ``credential_name`` is falsy.
        """
        if not self.credential_name:
            raise ValueError(
                "credential_name is required for RagFlow retrieval"
            )
        return self.credential_name

    def _secret_from(self, credential: Any) -> str:
        """Extract the non-empty ``credential_secret`` from a credential.

        Raises:
            ValueError: If the credential has no secret configured.
        """
        secret = credential.credential_secret
        if not secret:
            raise ValueError(
                f"Credential '{self.credential_name}' has no secret configured"
            )
        return secret

    async def _get_api_key_async(self, config: Optional[Config] = None) -> str:
        """Get the RagFlow API key (async).

        Args:
            config: Configuration forwarded to the credential lookup.

        Returns:
            str: The API key, i.e. the credential's secret.

        Raises:
            ValueError: If the credential name is not set or the credential
                has no secret.
        """
        credential_name = self._require_credential_name()

        # Deferred import; presumably avoids a circular import between
        # agentrun packages -- verify before hoisting to module level.
        from agentrun.credential import Credential

        credential = await Credential.get_by_name_async(
            credential_name, config=config
        )
        return self._secret_from(credential)

    def _get_api_key(self, config: Optional[Config] = None) -> str:
        """Get the RagFlow API key (sync).

        Args:
            config: Configuration forwarded to the credential lookup.

        Returns:
            str: The API key, i.e. the credential's secret.

        Raises:
            ValueError: If the credential name is not set or the credential
                has no secret.
        """
        credential_name = self._require_credential_name()

        # Deferred import; presumably avoids a circular import between
        # agentrun packages -- verify before hoisting to module level.
        from agentrun.credential import Credential

        credential = Credential.get_by_name(credential_name, config=config)
        return self._secret_from(credential)

    def _build_request_body(self, query: str) -> Dict[str, Any]:
        """Build the JSON body for the RagFlow retrieval request.

        Args:
            query: Query text, sent as ``question``.

        Returns:
            Dict[str, Any]: Request body. Pagination is fixed to page 1 with
            30 results. Optional retrieve settings are included only when
            they are not None.

        Raises:
            ValueError: If ``provider_settings`` is not set.
        """
        settings = self._require_provider_settings()

        body: Dict[str, Any] = {
            "question": query,
            "dataset_ids": settings.dataset_ids,
            "page": 1,
            "page_size": 30,
        }

        retrieve = self.retrieve_settings
        if retrieve:
            if retrieve.similarity_threshold is not None:
                body["similarity_threshold"] = retrieve.similarity_threshold
            if retrieve.vector_similarity_weight is not None:
                body["vector_similarity_weight"] = (
                    retrieve.vector_similarity_weight
                )
            if retrieve.cross_languages is not None:
                body["cross_languages"] = retrieve.cross_languages

        return body

    def _build_url(self) -> str:
        """Return the retrieval endpoint URL derived from ``base_url``."""
        base_url = self._require_provider_settings().base_url.rstrip("/")
        return f"{base_url}/api/v1/retrieval"

    @staticmethod
    def _build_headers(api_key: str) -> Dict[str, str]:
        """Return JSON + bearer-token request headers for ``api_key``."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

    def _parse_result(
        self, result: Dict[str, Any], query: str
    ) -> Dict[str, Any]:
        """Validate a decoded RagFlow response and wrap its ``data``.

        RagFlow signals application-level failures inside the JSON body with
        a non-zero ``code``, sometimes with ``"data": false`` and sometimes
        with no ``data`` key at all. These responses can arrive with HTTP
        200, so ``raise_for_status()`` alone does not catch them.

        Raises:
            RuntimeError: If the response reports a failure.
        """
        logger.debug(f"RagFlow retrieval result: {result}")

        data = result.get("data", {})
        code = result.get("code")
        if data is False or code not in (None, 0):
            raise RuntimeError(f"RagFlow retrieval failed: {result}")

        return {
            "data": data,
            "query": query,
            "knowledge_base_name": self.knowledge_base_name,
        }

    def _error_result(self, query: str, error: Exception) -> Dict[str, Any]:
        """Log ``error`` and return the best-effort error payload."""
        logger.warning(
            "Failed to retrieve from RagFlow knowledge base "
            f"'{self.knowledge_base_name}': {error}"
        )
        return {
            "data": f"Failed to retrieve: {error}",
            "query": query,
            "knowledge_base_name": self.knowledge_base_name,
            "error": True,
        }

    async def retrieve_async(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """RagFlow retrieval (async).

        Failures never propagate. Any exception is logged and turned into a
        payload with ``"error": True`` and a message in ``data``.

        Args:
            query: Query text.
            config: Configuration used for the credential lookup. The HTTP
                timeout is taken from ``self.config``.

        Returns:
            Dict[str, Any]: ``{"data", "query", "knowledge_base_name"}``,
            plus ``"error": True`` on failure.
        """
        try:
            self._require_provider_settings()
            api_key = await self._get_api_key_async(config)

            url = self._build_url()
            headers = self._build_headers(api_key)
            body = self._build_request_body(query)

            async with httpx.AsyncClient(
                timeout=self.config.get_timeout()
            ) as client:
                response = await client.post(url, json=body, headers=headers)
                response.raise_for_status()
                result = response.json()

            return self._parse_result(result, query)
        except Exception as e:
            return self._error_result(query, e)

    def retrieve(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """RagFlow retrieval (sync).

        Failures never propagate. Any exception is logged and turned into a
        payload with ``"error": True`` and a message in ``data``.

        Args:
            query: Query text.
            config: Configuration used for the credential lookup. The HTTP
                timeout is taken from ``self.config``.

        Returns:
            Dict[str, Any]: ``{"data", "query", "knowledge_base_name"}``,
            plus ``"error": True`` on failure.
        """
        try:
            self._require_provider_settings()
            api_key = self._get_api_key(config)

            url = self._build_url()
            headers = self._build_headers(api_key)
            body = self._build_request_body(query)

            with httpx.Client(timeout=self.config.get_timeout()) as client:
                response = client.post(url, json=body, headers=headers)
                response.raise_for_status()
                result = response.json()

            return self._parse_result(result, query)
        except Exception as e:
            return self._error_result(query, e)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
class BailianDataAPI(KnowledgeBaseDataAPI, ControlAPI):
    """Bailian knowledge base data API.

    Implements retrieval against Bailian: each configured index is queried
    in turn and all returned nodes are merged into one list.
    """

    # Optional retrieve-setting fields copied into ``RetrieveRequest`` when
    # they are not None.
    _RETRIEVE_SETTING_FIELDS = (
        "dense_similarity_top_k",
        "sparse_similarity_top_k",
        "rerank_min_score",
        "rerank_top_n",
    )

    def __init__(
        self,
        knowledge_base_name: str,
        config: Optional[Config] = None,
        provider_settings: Optional[BailianProviderSettings] = None,
        retrieve_settings: Optional[BailianRetrieveSettings] = None,
    ):
        """Initialize the Bailian knowledge base data API.

        Args:
            knowledge_base_name: Knowledge base name.
            config: Optional configuration, passed to both base classes.
            provider_settings: Bailian provider settings; ``workspace_id``
                and ``index_ids`` are read at retrieval time.
            retrieve_settings: Optional Bailian retrieval tuning settings.
        """
        KnowledgeBaseDataAPI.__init__(self, knowledge_base_name, config)
        ControlAPI.__init__(self, config)
        self.provider_settings = provider_settings
        self.retrieve_settings = retrieve_settings

    def _require_provider_settings(self) -> BailianProviderSettings:
        """Return ``provider_settings``, raising if it is not set.

        Raises:
            ValueError: If ``provider_settings`` is None.
        """
        if self.provider_settings is None:
            raise ValueError(
                "provider_settings is required for Bailian retrieval"
            )
        return self.provider_settings

    def _build_base_params(self, query: str) -> Dict[str, Any]:
        """Build ``RetrieveRequest`` kwargs shared by every index.

        The result holds ``query`` plus each field in
        ``_RETRIEVE_SETTING_FIELDS`` whose value is not None. ``index_id`` is
        added per request.
        """
        params: Dict[str, Any] = {"query": query}
        if self.retrieve_settings:
            for field_name in self._RETRIEVE_SETTING_FIELDS:
                value = getattr(self.retrieve_settings, field_name)
                if value is not None:
                    params[field_name] = value
        return params

    @staticmethod
    def _collect_nodes(response: Any) -> List[Dict[str, Any]]:
        """Convert the nodes of one retrieve response into plain dicts.

        Returns an empty list when the response has no body, data or nodes.
        Missing ``text``/``score``/``metadata`` attributes become None.
        """
        body = response.body
        if not (body and body.data and body.data.nodes):
            return []
        return [
            {
                "text": getattr(node, "text", None),
                "score": getattr(node, "score", None),
                "metadata": getattr(node, "metadata", None),
            }
            for node in body.data.nodes
        ]

    def _success_result(
        self, nodes: List[Dict[str, Any]], query: str
    ) -> Dict[str, Any]:
        """Wrap merged nodes into the success payload."""
        return {
            "data": nodes,
            "query": query,
            "knowledge_base_name": self.knowledge_base_name,
        }

    def _error_result(self, query: str, error: Exception) -> Dict[str, Any]:
        """Log ``error`` and return the best-effort error payload."""
        logger.warning(
            "Failed to retrieve from Bailian knowledge base "
            f"'{self.knowledge_base_name}': {error}"
        )
        return {
            "data": f"Failed to retrieve: {error}",
            "query": query,
            "knowledge_base_name": self.knowledge_base_name,
            "error": True,
        }

    async def retrieve_async(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """Bailian retrieval (async).

        Each index in ``provider_settings.index_ids`` is queried in turn and
        the nodes are concatenated. Any exception, including one from a
        single index, is logged and turned into an ``"error": True`` payload.

        Args:
            query: Query text.
            config: Configuration forwarded to ``_get_bailian_client``.

        Returns:
            Dict[str, Any]: ``{"data", "query", "knowledge_base_name"}``,
            plus ``"error": True`` on failure.
        """
        try:
            settings = self._require_provider_settings()
            workspace_id = settings.workspace_id
            index_ids = settings.index_ids

            base_params = self._build_base_params(query)
            client = self._get_bailian_client(config)

            all_nodes: List[Dict[str, Any]] = []
            for index_id in index_ids:
                request = bailian_models.RetrieveRequest(
                    **base_params, index_id=index_id
                )
                response = await client.retrieve_async(workspace_id, request)
                logger.debug(f"Bailian retrieve response: {response}")
                all_nodes.extend(self._collect_nodes(response))

            return self._success_result(all_nodes, query)
        except Exception as e:
            return self._error_result(query, e)

    def retrieve(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """Bailian retrieval (sync).

        Each index in ``provider_settings.index_ids`` is queried in turn and
        the nodes are concatenated. Any exception, including one from a
        single index, is logged and turned into an ``"error": True`` payload.

        Args:
            query: Query text.
            config: Configuration forwarded to ``_get_bailian_client``.

        Returns:
            Dict[str, Any]: ``{"data", "query", "knowledge_base_name"}``,
            plus ``"error": True`` on failure.
        """
        try:
            settings = self._require_provider_settings()
            workspace_id = settings.workspace_id
            index_ids = settings.index_ids

            base_params = self._build_base_params(query)
            client = self._get_bailian_client(config)

            all_nodes: List[Dict[str, Any]] = []
            for index_id in index_ids:
                request = bailian_models.RetrieveRequest(
                    **base_params, index_id=index_id
                )
                response = client.retrieve(workspace_id, request)
                logger.debug(f"Bailian retrieve response: {response}")
                all_nodes.extend(self._collect_nodes(response))

            return self._success_result(all_nodes, query)
        except Exception as e:
            return self._error_result(query, e)
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
def get_data_api(
    provider: KnowledgeBaseProvider,
    knowledge_base_name: str,
    config: Optional[Config] = None,
    provider_settings: Optional[
        Union[RagFlowProviderSettings, BailianProviderSettings]
    ] = None,
    retrieve_settings: Optional[
        Union[RagFlowRetrieveSettings, BailianRetrieveSettings]
    ] = None,
    credential_name: Optional[str] = None,
) -> KnowledgeBaseDataAPI:
    """Return the data API implementation for a provider type.

    Settings objects whose type does not match the selected provider are
    dropped (passed as None) and a warning is logged. Without the warning,
    the mismatch would only surface later as a "provider_settings is
    required" error during retrieval.

    Args:
        provider: Provider type, as an enum member or its string value
            ("ragflow" / "bailian").
        knowledge_base_name: Knowledge base name.
        config: Optional configuration.
        provider_settings: Provider settings for the selected provider.
        retrieve_settings: Retrieve settings for the selected provider.
        credential_name: Credential name (used by RagFlow only).

    Returns:
        KnowledgeBaseDataAPI: The matching data API instance.

    Raises:
        ValueError: If the provider type is not supported.
    """

    def _matching(value: Any, expected: type, label: str) -> Any:
        # Keep ``value`` only if it belongs to the selected provider; warn
        # instead of silently discarding a mismatched settings object.
        if value is None or isinstance(value, expected):
            return value
        logger.warning(
            f"Ignoring {label} of type {type(value).__name__} for provider "
            f"'{provider}'; expected {expected.__name__}"
        )
        return None

    if provider == KnowledgeBaseProvider.RAGFLOW or provider == "ragflow":
        return RagFlowDataAPI(
            knowledge_base_name,
            config,
            provider_settings=_matching(
                provider_settings, RagFlowProviderSettings, "provider_settings"
            ),
            retrieve_settings=_matching(
                retrieve_settings, RagFlowRetrieveSettings, "retrieve_settings"
            ),
            credential_name=credential_name,
        )

    if provider == KnowledgeBaseProvider.BAILIAN or provider == "bailian":
        return BailianDataAPI(
            knowledge_base_name,
            config,
            provider_settings=_matching(
                provider_settings, BailianProviderSettings, "provider_settings"
            ),
            retrieve_settings=_matching(
                retrieve_settings, BailianRetrieveSettings, "retrieve_settings"
            ),
        )

    raise ValueError(f"Unsupported provider type: {provider}")
|