agentrun-inner-test 0.0.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agentrun-inner-test might be problematic. Click here for more details.

Files changed (154)
  1. agentrun/__init__.py +358 -0
  2. agentrun/agent_runtime/__client_async_template.py +466 -0
  3. agentrun/agent_runtime/__endpoint_async_template.py +345 -0
  4. agentrun/agent_runtime/__init__.py +53 -0
  5. agentrun/agent_runtime/__runtime_async_template.py +477 -0
  6. agentrun/agent_runtime/api/__data_async_template.py +58 -0
  7. agentrun/agent_runtime/api/__init__.py +6 -0
  8. agentrun/agent_runtime/api/control.py +1362 -0
  9. agentrun/agent_runtime/api/data.py +98 -0
  10. agentrun/agent_runtime/client.py +868 -0
  11. agentrun/agent_runtime/endpoint.py +649 -0
  12. agentrun/agent_runtime/model.py +362 -0
  13. agentrun/agent_runtime/runtime.py +904 -0
  14. agentrun/credential/__client_async_template.py +177 -0
  15. agentrun/credential/__credential_async_template.py +216 -0
  16. agentrun/credential/__init__.py +28 -0
  17. agentrun/credential/api/__init__.py +5 -0
  18. agentrun/credential/api/control.py +606 -0
  19. agentrun/credential/client.py +319 -0
  20. agentrun/credential/credential.py +381 -0
  21. agentrun/credential/model.py +248 -0
  22. agentrun/integration/__init__.py +21 -0
  23. agentrun/integration/agentscope/__init__.py +13 -0
  24. agentrun/integration/agentscope/adapter.py +17 -0
  25. agentrun/integration/agentscope/builtin.py +88 -0
  26. agentrun/integration/agentscope/message_adapter.py +185 -0
  27. agentrun/integration/agentscope/model_adapter.py +60 -0
  28. agentrun/integration/agentscope/tool_adapter.py +59 -0
  29. agentrun/integration/builtin/__init__.py +18 -0
  30. agentrun/integration/builtin/knowledgebase.py +137 -0
  31. agentrun/integration/builtin/model.py +93 -0
  32. agentrun/integration/builtin/sandbox.py +1234 -0
  33. agentrun/integration/builtin/toolset.py +47 -0
  34. agentrun/integration/crewai/__init__.py +13 -0
  35. agentrun/integration/crewai/adapter.py +9 -0
  36. agentrun/integration/crewai/builtin.py +88 -0
  37. agentrun/integration/crewai/model_adapter.py +31 -0
  38. agentrun/integration/crewai/tool_adapter.py +26 -0
  39. agentrun/integration/google_adk/__init__.py +13 -0
  40. agentrun/integration/google_adk/adapter.py +15 -0
  41. agentrun/integration/google_adk/builtin.py +88 -0
  42. agentrun/integration/google_adk/message_adapter.py +144 -0
  43. agentrun/integration/google_adk/model_adapter.py +46 -0
  44. agentrun/integration/google_adk/tool_adapter.py +235 -0
  45. agentrun/integration/langchain/__init__.py +31 -0
  46. agentrun/integration/langchain/adapter.py +15 -0
  47. agentrun/integration/langchain/builtin.py +94 -0
  48. agentrun/integration/langchain/message_adapter.py +141 -0
  49. agentrun/integration/langchain/model_adapter.py +37 -0
  50. agentrun/integration/langchain/tool_adapter.py +50 -0
  51. agentrun/integration/langgraph/__init__.py +36 -0
  52. agentrun/integration/langgraph/adapter.py +20 -0
  53. agentrun/integration/langgraph/agent_converter.py +1073 -0
  54. agentrun/integration/langgraph/builtin.py +88 -0
  55. agentrun/integration/pydantic_ai/__init__.py +13 -0
  56. agentrun/integration/pydantic_ai/adapter.py +13 -0
  57. agentrun/integration/pydantic_ai/builtin.py +88 -0
  58. agentrun/integration/pydantic_ai/model_adapter.py +44 -0
  59. agentrun/integration/pydantic_ai/tool_adapter.py +19 -0
  60. agentrun/integration/utils/__init__.py +112 -0
  61. agentrun/integration/utils/adapter.py +560 -0
  62. agentrun/integration/utils/canonical.py +164 -0
  63. agentrun/integration/utils/converter.py +134 -0
  64. agentrun/integration/utils/model.py +110 -0
  65. agentrun/integration/utils/tool.py +1759 -0
  66. agentrun/knowledgebase/__client_async_template.py +173 -0
  67. agentrun/knowledgebase/__init__.py +53 -0
  68. agentrun/knowledgebase/__knowledgebase_async_template.py +438 -0
  69. agentrun/knowledgebase/api/__data_async_template.py +414 -0
  70. agentrun/knowledgebase/api/__init__.py +19 -0
  71. agentrun/knowledgebase/api/control.py +606 -0
  72. agentrun/knowledgebase/api/data.py +624 -0
  73. agentrun/knowledgebase/client.py +311 -0
  74. agentrun/knowledgebase/knowledgebase.py +748 -0
  75. agentrun/knowledgebase/model.py +270 -0
  76. agentrun/memory_collection/__client_async_template.py +178 -0
  77. agentrun/memory_collection/__init__.py +37 -0
  78. agentrun/memory_collection/__memory_collection_async_template.py +457 -0
  79. agentrun/memory_collection/api/__init__.py +5 -0
  80. agentrun/memory_collection/api/control.py +610 -0
  81. agentrun/memory_collection/client.py +323 -0
  82. agentrun/memory_collection/memory_collection.py +844 -0
  83. agentrun/memory_collection/model.py +162 -0
  84. agentrun/model/__client_async_template.py +357 -0
  85. agentrun/model/__init__.py +57 -0
  86. agentrun/model/__model_proxy_async_template.py +270 -0
  87. agentrun/model/__model_service_async_template.py +267 -0
  88. agentrun/model/api/__init__.py +6 -0
  89. agentrun/model/api/control.py +1173 -0
  90. agentrun/model/api/data.py +196 -0
  91. agentrun/model/client.py +674 -0
  92. agentrun/model/model.py +235 -0
  93. agentrun/model/model_proxy.py +439 -0
  94. agentrun/model/model_service.py +438 -0
  95. agentrun/sandbox/__aio_sandbox_async_template.py +523 -0
  96. agentrun/sandbox/__browser_sandbox_async_template.py +110 -0
  97. agentrun/sandbox/__client_async_template.py +491 -0
  98. agentrun/sandbox/__code_interpreter_sandbox_async_template.py +463 -0
  99. agentrun/sandbox/__init__.py +69 -0
  100. agentrun/sandbox/__sandbox_async_template.py +463 -0
  101. agentrun/sandbox/__template_async_template.py +152 -0
  102. agentrun/sandbox/aio_sandbox.py +912 -0
  103. agentrun/sandbox/api/__aio_data_async_template.py +335 -0
  104. agentrun/sandbox/api/__browser_data_async_template.py +140 -0
  105. agentrun/sandbox/api/__code_interpreter_data_async_template.py +206 -0
  106. agentrun/sandbox/api/__init__.py +19 -0
  107. agentrun/sandbox/api/__sandbox_data_async_template.py +107 -0
  108. agentrun/sandbox/api/aio_data.py +551 -0
  109. agentrun/sandbox/api/browser_data.py +172 -0
  110. agentrun/sandbox/api/code_interpreter_data.py +396 -0
  111. agentrun/sandbox/api/control.py +1051 -0
  112. agentrun/sandbox/api/playwright_async.py +492 -0
  113. agentrun/sandbox/api/playwright_sync.py +492 -0
  114. agentrun/sandbox/api/sandbox_data.py +154 -0
  115. agentrun/sandbox/browser_sandbox.py +185 -0
  116. agentrun/sandbox/client.py +925 -0
  117. agentrun/sandbox/code_interpreter_sandbox.py +823 -0
  118. agentrun/sandbox/model.py +384 -0
  119. agentrun/sandbox/sandbox.py +848 -0
  120. agentrun/sandbox/template.py +217 -0
  121. agentrun/server/__init__.py +191 -0
  122. agentrun/server/agui_normalizer.py +180 -0
  123. agentrun/server/agui_protocol.py +797 -0
  124. agentrun/server/invoker.py +309 -0
  125. agentrun/server/model.py +427 -0
  126. agentrun/server/openai_protocol.py +535 -0
  127. agentrun/server/protocol.py +140 -0
  128. agentrun/server/server.py +208 -0
  129. agentrun/toolset/__client_async_template.py +62 -0
  130. agentrun/toolset/__init__.py +51 -0
  131. agentrun/toolset/__toolset_async_template.py +204 -0
  132. agentrun/toolset/api/__init__.py +17 -0
  133. agentrun/toolset/api/control.py +262 -0
  134. agentrun/toolset/api/mcp.py +100 -0
  135. agentrun/toolset/api/openapi.py +1251 -0
  136. agentrun/toolset/client.py +102 -0
  137. agentrun/toolset/model.py +321 -0
  138. agentrun/toolset/toolset.py +271 -0
  139. agentrun/utils/__data_api_async_template.py +721 -0
  140. agentrun/utils/__init__.py +5 -0
  141. agentrun/utils/__resource_async_template.py +158 -0
  142. agentrun/utils/config.py +270 -0
  143. agentrun/utils/control_api.py +105 -0
  144. agentrun/utils/data_api.py +1121 -0
  145. agentrun/utils/exception.py +151 -0
  146. agentrun/utils/helper.py +108 -0
  147. agentrun/utils/log.py +77 -0
  148. agentrun/utils/model.py +168 -0
  149. agentrun/utils/resource.py +291 -0
  150. agentrun_inner_test-0.0.62.dist-info/METADATA +265 -0
  151. agentrun_inner_test-0.0.62.dist-info/RECORD +154 -0
  152. agentrun_inner_test-0.0.62.dist-info/WHEEL +5 -0
  153. agentrun_inner_test-0.0.62.dist-info/licenses/LICENSE +201 -0
  154. agentrun_inner_test-0.0.62.dist-info/top_level.txt +1 -0
@@ -0,0 +1,624 @@
1
+ """
2
+ This file is auto generated by the code generation script.
3
+ Do not modify this file manually.
4
+ Use the `make codegen` command to regenerate.
5
+
6
+ 当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。
7
+ 使用 `make codegen` 命令重新生成。
8
+
9
+ source: agentrun/knowledgebase/api/__data_async_template.py
10
+
11
+ KnowledgeBase 数据链路 API / KnowledgeBase Data API
12
+
13
+ 提供知识库检索功能的数据链路 API。
14
+ Provides data API for knowledge base retrieval operations.
15
+
16
+ 根据不同的 provider 类型(ragflow / bailian)分发到不同的实现。
17
+ Dispatches to different implementations based on provider type (ragflow / bailian).
18
+ """
19
+
20
+ from abc import ABC, abstractmethod
21
+ from typing import Any, Dict, List, Optional, Union
22
+
23
+ from alibabacloud_bailian20231229 import models as bailian_models
24
+ import httpx
25
+
26
+ from agentrun.utils.config import Config
27
+ from agentrun.utils.control_api import ControlAPI
28
+ from agentrun.utils.data_api import DataAPI, ResourceType
29
+ from agentrun.utils.log import logger
30
+
31
+ from ..model import (
32
+ BailianProviderSettings,
33
+ BailianRetrieveSettings,
34
+ KnowledgeBaseProvider,
35
+ RagFlowProviderSettings,
36
+ RagFlowRetrieveSettings,
37
+ )
38
+
39
+
40
class KnowledgeBaseDataAPI(ABC):
    """Knowledge base data API base class.

    Defines the abstract retrieval interface that provider-specific
    subclasses implement. Implementations in this module return a dict with
    ``data``, ``query`` and ``knowledge_base_name`` keys, plus ``error: True``
    when retrieval failed.
    """

    def __init__(
        self,
        knowledge_base_name: str,
        config: Optional[Config] = None,
    ):
        """Initialize the knowledge base data API.

        Args:
            knowledge_base_name: Knowledge base name.
            config: Optional configuration; normalized through
                ``Config.with_configs`` and stored as ``self.config``.
        """
        self.knowledge_base_name = knowledge_base_name
        self.config = Config.with_configs(config)

    @abstractmethod
    async def retrieve_async(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """Retrieve from the knowledge base (asynchronous).

        Args:
            query: Query text.
            config: Optional per-call configuration.

        Returns:
            Dict[str, Any]: Retrieval results.
        """
        raise NotImplementedError("Subclasses must implement retrieve_async")

    @abstractmethod
    def retrieve(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """Retrieve from the knowledge base (synchronous).

        Args:
            query: Query text.
            config: Optional per-call configuration.

        Returns:
            Dict[str, Any]: Retrieval results.
        """
        raise NotImplementedError("Subclasses must implement retrieve")
94
+
95
+
96
class RagFlowDataAPI(KnowledgeBaseDataAPI):
    """RagFlow knowledge base data API.

    Retrieves chunks from a RagFlow server by POSTing to
    ``{base_url}/api/v1/retrieval`` with a bearer token taken from the secret
    of a named credential. Failures are not raised: they are logged and
    returned as a result dict carrying ``error: True``.
    """

    def __init__(
        self,
        knowledge_base_name: str,
        config: Optional[Config] = None,
        provider_settings: Optional[RagFlowProviderSettings] = None,
        retrieve_settings: Optional[RagFlowRetrieveSettings] = None,
        credential_name: Optional[str] = None,
    ):
        """Initialize the RagFlow data API.

        Args:
            knowledge_base_name: Knowledge base name.
            config: Optional configuration.
            provider_settings: RagFlow provider settings; ``base_url`` and
                ``dataset_ids`` are read from it. Required for retrieval.
            retrieve_settings: Optional RagFlow retrieve settings
                (``similarity_threshold``, ``vector_similarity_weight``,
                ``cross_languages``).
            credential_name: Name of the credential whose secret is sent as
                the bearer token. Required for retrieval.
        """
        super().__init__(knowledge_base_name, config)
        self.provider_settings = provider_settings
        self.retrieve_settings = retrieve_settings
        self.credential_name = credential_name

    # ------------------------------------------------------------------
    # Shared helpers used by both the sync and async code paths
    # ------------------------------------------------------------------

    def _require_provider_settings(self) -> RagFlowProviderSettings:
        """Return ``provider_settings``, raising ValueError if unset."""
        if self.provider_settings is None:
            raise ValueError(
                "provider_settings is required for RagFlow retrieval"
            )
        return self.provider_settings

    def _require_credential_name(self) -> str:
        """Return ``credential_name``, raising ValueError if empty or unset."""
        if not self.credential_name:
            raise ValueError(
                "credential_name is required for RagFlow retrieval"
            )
        return self.credential_name

    def _credential_config(self, config: Optional[Config]) -> Config:
        """Choose the config used for the credential lookup.

        A per-call ``config`` wins. Otherwise the instance config is used, so
        a config passed to the constructor is honored even when
        ``retrieve``/``retrieve_async`` is called without one.
        """
        return config if config is not None else self.config

    def _secret_from(self, credential: Any) -> str:
        """Return ``credential.credential_secret``, raising if it is empty."""
        if not credential.credential_secret:
            raise ValueError(
                f"Credential '{self.credential_name}' has no secret configured"
            )
        return credential.credential_secret

    async def _get_api_key_async(self, config: Optional[Config] = None) -> str:
        """Resolve the RagFlow API key from the named credential (async).

        Args:
            config: Optional per-call configuration for the credential lookup;
                falls back to the instance config when omitted.

        Returns:
            str: API key.

        Raises:
            ValueError: ``credential_name`` is unset, or the credential has
                no secret configured.
        """
        credential_name = self._require_credential_name()

        # Local import kept from the original; presumably avoids an import
        # cycle at module load time. NOTE(review): confirm.
        from agentrun.credential import Credential

        credential = await Credential.get_by_name_async(
            credential_name, config=self._credential_config(config)
        )
        return self._secret_from(credential)

    def _get_api_key(self, config: Optional[Config] = None) -> str:
        """Resolve the RagFlow API key from the named credential (sync).

        Args:
            config: Optional per-call configuration for the credential lookup;
                falls back to the instance config when omitted.

        Returns:
            str: API key.

        Raises:
            ValueError: ``credential_name`` is unset, or the credential has
                no secret configured.
        """
        credential_name = self._require_credential_name()

        # Local import kept from the original; presumably avoids an import
        # cycle at module load time. NOTE(review): confirm.
        from agentrun.credential import Credential

        credential = Credential.get_by_name(
            credential_name, config=self._credential_config(config)
        )
        return self._secret_from(credential)

    def _retrieval_url(self) -> str:
        """Return the RagFlow retrieval endpoint URL for ``base_url``."""
        base_url = self._require_provider_settings().base_url.rstrip("/")
        return f"{base_url}/api/v1/retrieval"

    @staticmethod
    def _build_headers(api_key: str) -> Dict[str, str]:
        """Return the JSON + bearer-token request headers."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

    def _build_request_body(self, query: str) -> Dict[str, Any]:
        """Build the retrieval request body.

        Args:
            query: Query text.

        Returns:
            Dict[str, Any]: Request body. It always requests page 1 with 30
            results and includes optional retrieve settings only when set.

        Raises:
            ValueError: ``provider_settings`` is unset.
        """
        provider_settings = self._require_provider_settings()

        body: Dict[str, Any] = {
            "question": query,
            "dataset_ids": provider_settings.dataset_ids,
            "page": 1,
            "page_size": 30,
        }

        settings = self.retrieve_settings
        if settings:
            if settings.similarity_threshold is not None:
                body["similarity_threshold"] = settings.similarity_threshold
            if settings.vector_similarity_weight is not None:
                body["vector_similarity_weight"] = (
                    settings.vector_similarity_weight
                )
            if settings.cross_languages is not None:
                body["cross_languages"] = settings.cross_languages

        return body

    def _retrieval_success(self, query: str, result: Any) -> Dict[str, Any]:
        """Wrap a decoded RagFlow response into the result dict.

        Raises:
            RuntimeError: The response's ``data`` value is ``False``, which
                is treated as a failed retrieval.
        """
        logger.debug(f"RagFlow retrieval result: {result}")
        data = result.get("data", {})

        # Identity check: ``data == False`` would also match ``0``/``0.0``.
        if data is False:
            raise RuntimeError(f"RagFlow retrieval failed: {result}")

        return {
            "data": data,
            "query": query,
            "knowledge_base_name": self.knowledge_base_name,
        }

    def _retrieval_failure(
        self, query: str, error: Exception
    ) -> Dict[str, Any]:
        """Log ``error`` and return it as an ``error: True`` result dict."""
        logger.warning(
            "Failed to retrieve from RagFlow knowledge base "
            f"'{self.knowledge_base_name}': {error}"
        )
        return {
            "data": f"Failed to retrieve: {error}",
            "query": query,
            "knowledge_base_name": self.knowledge_base_name,
            "error": True,
        }

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def retrieve_async(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """RagFlow retrieval (asynchronous).

        Args:
            query: Query text.
            config: Optional per-call configuration, used for the credential
                lookup. The HTTP timeout comes from the instance config.

        Returns:
            Dict[str, Any]: ``data``, ``query`` and ``knowledge_base_name``;
            on any failure ``data`` holds the error message and
            ``error`` is ``True``.
        """
        try:
            # Validate settings before resolving the credential.
            self._require_provider_settings()
            api_key = await self._get_api_key_async(config)

            url = self._retrieval_url()
            headers = self._build_headers(api_key)
            body = self._build_request_body(query)

            async with httpx.AsyncClient(
                timeout=self.config.get_timeout()
            ) as client:
                response = await client.post(url, json=body, headers=headers)
                response.raise_for_status()
                result = response.json()

            return self._retrieval_success(query, result)
        except Exception as e:
            # Best-effort by design: failures are reported in the result.
            return self._retrieval_failure(query, e)

    def retrieve(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """RagFlow retrieval (synchronous).

        Args:
            query: Query text.
            config: Optional per-call configuration, used for the credential
                lookup. The HTTP timeout comes from the instance config.

        Returns:
            Dict[str, Any]: ``data``, ``query`` and ``knowledge_base_name``;
            on any failure ``data`` holds the error message and
            ``error`` is ``True``.
        """
        try:
            # Validate settings before resolving the credential.
            self._require_provider_settings()
            api_key = self._get_api_key(config)

            url = self._retrieval_url()
            headers = self._build_headers(api_key)
            body = self._build_request_body(query)

            with httpx.Client(timeout=self.config.get_timeout()) as client:
                response = client.post(url, json=body, headers=headers)
                response.raise_for_status()
                result = response.json()

            return self._retrieval_success(query, result)
        except Exception as e:
            # Best-effort by design: failures are reported in the result.
            return self._retrieval_failure(query, e)
340
+
341
+
342
class BailianDataAPI(KnowledgeBaseDataAPI, ControlAPI):
    """Bailian knowledge base data API.

    Issues one ``RetrieveRequest`` per configured index in the workspace and
    concatenates the returned nodes. Failures are not raised: they are logged
    and returned as a result dict carrying ``error: True``.
    """

    # Optional retrieve settings forwarded to ``RetrieveRequest`` when set,
    # in this order.
    _RETRIEVE_SETTING_FIELDS = (
        "dense_similarity_top_k",
        "sparse_similarity_top_k",
        "rerank_min_score",
        "rerank_top_n",
    )

    def __init__(
        self,
        knowledge_base_name: str,
        config: Optional[Config] = None,
        provider_settings: Optional[BailianProviderSettings] = None,
        retrieve_settings: Optional[BailianRetrieveSettings] = None,
    ):
        """Initialize the Bailian data API.

        Args:
            knowledge_base_name: Knowledge base name.
            config: Optional configuration, passed to both base initializers.
            provider_settings: Bailian provider settings; ``workspace_id`` and
                ``index_ids`` are read from it. Required for retrieval.
            retrieve_settings: Optional Bailian retrieve settings.
        """
        KnowledgeBaseDataAPI.__init__(self, knowledge_base_name, config)
        ControlAPI.__init__(self, config)
        self.provider_settings = provider_settings
        self.retrieve_settings = retrieve_settings

    # ------------------------------------------------------------------
    # Shared helpers used by both the sync and async code paths
    # ------------------------------------------------------------------

    def _require_provider_settings(self) -> BailianProviderSettings:
        """Return ``provider_settings``, raising ValueError if unset."""
        if self.provider_settings is None:
            raise ValueError(
                "provider_settings is required for Bailian retrieval"
            )
        return self.provider_settings

    def _build_retrieve_params(self, query: str) -> Dict[str, Any]:
        """Build ``RetrieveRequest`` kwargs (without ``index_id``).

        Only retrieve settings that are not ``None`` are included.
        """
        params: Dict[str, Any] = {"query": query}
        settings = self.retrieve_settings
        if settings:
            for field_name in self._RETRIEVE_SETTING_FIELDS:
                value = getattr(settings, field_name)
                if value is not None:
                    params[field_name] = value
        return params

    @staticmethod
    def _nodes_from_response(response: Any) -> List[Dict[str, Any]]:
        """Convert ``response.body.data.nodes`` into plain dicts.

        Returns an empty list when the body, its data or its nodes are
        missing or empty. Missing node attributes become ``None``.
        """
        body = response.body
        if not (body and body.data and body.data.nodes):
            return []
        return [
            {
                "text": getattr(node, "text", None),
                "score": getattr(node, "score", None),
                "metadata": getattr(node, "metadata", None),
            }
            for node in body.data.nodes
        ]

    def _retrieval_success(
        self, query: str, nodes: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Wrap the merged nodes into the result dict."""
        return {
            "data": nodes,
            "query": query,
            "knowledge_base_name": self.knowledge_base_name,
        }

    def _retrieval_failure(
        self, query: str, error: Exception
    ) -> Dict[str, Any]:
        """Log ``error`` and return it as an ``error: True`` result dict."""
        logger.warning(
            "Failed to retrieve from Bailian knowledge base "
            f"'{self.knowledge_base_name}': {error}"
        )
        return {
            "data": f"Failed to retrieve: {error}",
            "query": query,
            "knowledge_base_name": self.knowledge_base_name,
            "error": True,
        }

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def retrieve_async(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """Bailian retrieval (asynchronous).

        Args:
            query: Query text.
            config: Optional per-call configuration, forwarded to
                ``self._get_bailian_client``.

        Returns:
            Dict[str, Any]: ``data`` (list of node dicts from all indexes,
            in index order), ``query`` and ``knowledge_base_name``; on any
            failure ``data`` holds the error message and ``error`` is ``True``.
        """
        try:
            settings = self._require_provider_settings()
            workspace_id = settings.workspace_id
            index_ids = settings.index_ids
            if index_ids is None:
                raise ValueError("index_ids is required for Bailian retrieval")

            request_params = self._build_retrieve_params(query)
            # NOTE(review): _get_bailian_client is not defined in this module;
            # presumably provided by ControlAPI.
            client = self._get_bailian_client(config)

            all_nodes: List[Dict[str, Any]] = []
            for index_id in index_ids:
                request = bailian_models.RetrieveRequest(
                    **request_params, index_id=index_id
                )
                response = await client.retrieve_async(workspace_id, request)
                logger.debug(f"Bailian retrieve response: {response}")
                all_nodes.extend(self._nodes_from_response(response))

            return self._retrieval_success(query, all_nodes)
        except Exception as e:
            # Best-effort by design: failures are reported in the result.
            return self._retrieval_failure(query, e)

    def retrieve(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """Bailian retrieval (synchronous).

        Args:
            query: Query text.
            config: Optional per-call configuration, forwarded to
                ``self._get_bailian_client``.

        Returns:
            Dict[str, Any]: ``data`` (list of node dicts from all indexes,
            in index order), ``query`` and ``knowledge_base_name``; on any
            failure ``data`` holds the error message and ``error`` is ``True``.
        """
        try:
            settings = self._require_provider_settings()
            workspace_id = settings.workspace_id
            index_ids = settings.index_ids
            if index_ids is None:
                raise ValueError("index_ids is required for Bailian retrieval")

            request_params = self._build_retrieve_params(query)
            # NOTE(review): _get_bailian_client is not defined in this module;
            # presumably provided by ControlAPI.
            client = self._get_bailian_client(config)

            all_nodes: List[Dict[str, Any]] = []
            for index_id in index_ids:
                request = bailian_models.RetrieveRequest(
                    **request_params, index_id=index_id
                )
                response = client.retrieve(workspace_id, request)
                logger.debug(f"Bailian retrieve response: {response}")
                all_nodes.extend(self._nodes_from_response(response))

            return self._retrieval_success(query, all_nodes)
        except Exception as e:
            # Best-effort by design: failures are reported in the result.
            return self._retrieval_failure(query, e)
558
+
559
+
560
def get_data_api(
    provider: KnowledgeBaseProvider,
    knowledge_base_name: str,
    config: Optional[Config] = None,
    provider_settings: Optional[
        Union[RagFlowProviderSettings, BailianProviderSettings]
    ] = None,
    retrieve_settings: Optional[
        Union[RagFlowRetrieveSettings, BailianRetrieveSettings]
    ] = None,
    credential_name: Optional[str] = None,
) -> KnowledgeBaseDataAPI:
    """Create the data API implementation matching ``provider``.

    ``provider`` may be the enum member or its plain string form
    (``"ragflow"`` / ``"bailian"``). Settings objects whose type does not
    match the selected provider are replaced by ``None`` rather than
    rejected.

    Args:
        provider: Provider type.
        knowledge_base_name: Knowledge base name.
        config: Optional configuration.
        provider_settings: Provider settings for either provider.
        retrieve_settings: Retrieve settings for either provider.
        credential_name: Credential name (only used by RagFlow).

    Returns:
        KnowledgeBaseDataAPI: A ``RagFlowDataAPI`` or ``BailianDataAPI``.

    Raises:
        ValueError: ``provider`` is not a supported provider type.
    """

    def _only_if(value: Any, expected_type: type) -> Any:
        # Keep the settings object only when it belongs to this provider.
        return value if isinstance(value, expected_type) else None

    if provider in (KnowledgeBaseProvider.RAGFLOW, "ragflow"):
        return RagFlowDataAPI(
            knowledge_base_name,
            config,
            provider_settings=_only_if(
                provider_settings, RagFlowProviderSettings
            ),
            retrieve_settings=_only_if(
                retrieve_settings, RagFlowRetrieveSettings
            ),
            credential_name=credential_name,
        )

    if provider in (KnowledgeBaseProvider.BAILIAN, "bailian"):
        return BailianDataAPI(
            knowledge_base_name,
            config,
            provider_settings=_only_if(
                provider_settings, BailianProviderSettings
            ),
            retrieve_settings=_only_if(
                retrieve_settings, BailianRetrieveSettings
            ),
        )

    raise ValueError(f"Unsupported provider type: {provider}")