fiuai-sdk-python 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fiuai_sdk_python/resp.py +51 -3
- {fiuai_sdk_python-0.7.0.dist-info → fiuai_sdk_python-0.7.2.dist-info}/METADATA +1 -1
- {fiuai_sdk_python-0.7.0.dist-info → fiuai_sdk_python-0.7.2.dist-info}/RECORD +5 -12
- fiuai_sdk_python/pkg/db/__init__.py +0 -35
- fiuai_sdk_python/pkg/db/config.py +0 -25
- fiuai_sdk_python/pkg/db/errors.py +0 -27
- fiuai_sdk_python/pkg/db/manager.py +0 -439
- fiuai_sdk_python/pkg/db/utils.py +0 -78
- fiuai_sdk_python/pkg/vector/__init__.py +0 -23
- fiuai_sdk_python/pkg/vector/vector.py +0 -853
- {fiuai_sdk_python-0.7.0.dist-info → fiuai_sdk_python-0.7.2.dist-info}/WHEEL +0 -0
- {fiuai_sdk_python-0.7.0.dist-info → fiuai_sdk_python-0.7.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,853 +0,0 @@
|
|
|
1
|
-
# -- coding: utf-8 --
|
|
2
|
-
# Project: fiuai-world
|
|
3
|
-
# Created Date: 2025-01-27
|
|
4
|
-
# Author: liming
|
|
5
|
-
# Email: lmlala@aliyun.com
|
|
6
|
-
# Copyright (c) 2025 FiuAI
|
|
7
|
-
|
|
8
|
-
import os
|
|
9
|
-
import time
|
|
10
|
-
import threading
|
|
11
|
-
from typing import Optional, Dict, Any, List, Callable, Union
|
|
12
|
-
from dataclasses import dataclass
|
|
13
|
-
|
|
14
|
-
from qdrant_client import QdrantClient
|
|
15
|
-
from qdrant_client.models import (
|
|
16
|
-
Distance,
|
|
17
|
-
VectorParams,
|
|
18
|
-
PointStruct,
|
|
19
|
-
Filter,
|
|
20
|
-
FieldCondition,
|
|
21
|
-
MatchValue,
|
|
22
|
-
CollectionStatus,
|
|
23
|
-
UpdateStatus,
|
|
24
|
-
)
|
|
25
|
-
from qdrant_client.http import exceptions as qdrant_exceptions
|
|
26
|
-
|
|
27
|
-
from utils import get_logger
|
|
28
|
-
from utils.errors import FiuaiBaseError
|
|
29
|
-
|
|
30
|
-
logger = get_logger(__name__)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
class QdrantError(FiuaiBaseError):
    """Error raised for Qdrant-related failures."""
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
@dataclass
class QdrantConfig:
    """Connection settings for a Qdrant server."""

    host: str
    port: int = 6333
    api_key: Optional[str] = None
    timeout: int = 30
    retry_count: int = 3
    retry_delay: int = 1
    prefer_grpc: bool = False
    https: bool = False
    url: Optional[str] = None  # when provided, the URL takes priority over host/port

    @staticmethod
    def _env_str(name: str, default: Optional[str] = None) -> Optional[str]:
        """Return the stripped value of env var *name*, or *default* if unset/blank."""
        raw = os.getenv(name)
        if raw is None or not raw.strip():
            return default
        return raw.strip()

    @staticmethod
    def _env_int(name: str, default: int) -> int:
        """Return env var *name* parsed as int, or *default* if unset/blank.

        Raises:
            ValueError: the variable is set to something that is not an integer.
        """
        raw = os.getenv(name)
        if raw is None or not raw.strip():
            return default
        try:
            return int(raw)
        except ValueError as err:
            # Name the offending variable; a bare int() error does not.
            raise ValueError(
                f"Environment variable {name} must be an integer, got {raw!r}"
            ) from err

    @staticmethod
    def _env_flag(name: str, default: bool = False) -> bool:
        """Return env var *name* parsed as a boolean flag.

        ``1``, ``true``, ``yes`` and ``on`` (any case, surrounding whitespace
        ignored) count as true; any other non-blank value counts as false.
        """
        raw = os.getenv(name)
        if raw is None or not raw.strip():
            return default
        return raw.strip().lower() in ('1', 'true', 'yes', 'on')

    @classmethod
    def from_env(cls, host: Optional[str] = None, port: Optional[int] = None) -> 'QdrantConfig':
        """Build a configuration from ``QDRANT_*`` environment variables.

        Args:
            host: Qdrant host. When falsy, ``QDRANT_HOST`` is used, falling
                back to ``127.0.0.1``.
            port: Qdrant port. When falsy, ``QDRANT_PORT`` is used, falling
                back to ``6333``.

        Returns:
            QdrantConfig: the populated configuration object.

        Raises:
            ValueError: a numeric variable holds a non-integer value.
        """
        return cls(
            host=host or cls._env_str('QDRANT_HOST', '127.0.0.1'),
            port=port or cls._env_int('QDRANT_PORT', 6333),
            api_key=cls._env_str('QDRANT_API_KEY'),
            timeout=cls._env_int('QDRANT_TIMEOUT', 30),
            retry_count=cls._env_int('QDRANT_RETRY_COUNT', 3),
            retry_delay=cls._env_int('QDRANT_RETRY_DELAY', 1),
            prefer_grpc=cls._env_flag('QDRANT_PREFER_GRPC'),
            https=cls._env_flag('QDRANT_HTTPS'),
            url=cls._env_str('QDRANT_URL'),
        )
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
class QdrantManager:
    """Singleton manager around a ``QdrantClient``.

    Provides connection health checks, reconnection and retries, plus basic
    collection/point CRUD helpers and a small hook registry.
    """

    _instance: Optional['QdrantManager'] = None
    _lock = threading.Lock()
    _initialized = False

    # Proxy-related environment variables that are temporarily removed while
    # the client is created, so the client is built with proxies disabled.
    _PROXY_ENV_VARS = (
        'HTTP_PROXY', 'HTTPS_PROXY', 'http_proxy', 'https_proxy',
        'ALL_PROXY', 'all_proxy', 'SOCKS_PROXY', 'socks_proxy',
        'NO_PROXY', 'no_proxy',
    )

    def __new__(cls):
        # Check, lock, then check again before creating the single instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every QdrantManager() call; set state up only once.
        if self._initialized:
            return

        self._config: Optional[QdrantConfig] = None
        self._client: Optional[QdrantClient] = None
        self._is_connected = False
        self._last_check_time = 0
        self._check_interval = 30  # seconds between connection health checks
        self._hooks: Dict[str, List[Callable]] = {
            'before_query': [],
            'after_query': [],
            'before_create_collection': [],
            'after_create_collection': [],
            'before_delete_collection': [],
            'after_delete_collection': [],
        }
        self._initialized = True

    @staticmethod
    def _build_client(config: QdrantConfig) -> QdrantClient:
        """Create a client from *config*, preferring ``config.url`` when set."""
        if config.url:
            return QdrantClient(
                url=config.url,
                api_key=config.api_key,
                timeout=config.timeout,
                prefer_grpc=config.prefer_grpc,
            )
        return QdrantClient(
            host=config.host,
            port=config.port,
            api_key=config.api_key,
            timeout=config.timeout,
            prefer_grpc=config.prefer_grpc,
            https=config.https,
        )

    def _dispose_client(self) -> None:
        """Drop the current client, closing it when it exposes ``close()``."""
        client, self._client = self._client, None
        if client is None:
            return
        close = getattr(client, 'close', None)
        if not callable(close):
            return
        try:
            close()
        except Exception as e:
            # Best effort: a failing close must not block shutdown/reconnect.
            logger.warning(f"Failed to close Qdrant client: {e}")

    def _ensure_initialized(self) -> None:
        """Raise unless ``initialize()`` has succeeded and ``close()`` was not called.

        The guard is based on the stored config rather than on the connected
        flag: that flag is cleared on purpose to request a reconnect, so
        guarding on it would make the reconnect path unreachable after the
        first connection loss.

        Raises:
            QdrantError: the manager has no configuration.
        """
        if self._config is None:
            raise QdrantError("Qdrant 未初始化,请先调用 initialize()")

    def initialize(self, config: QdrantConfig):
        """Create the Qdrant client and verify the connection.

        Does nothing (besides logging a warning) when already connected.

        Args:
            config: Qdrant connection settings.

        Raises:
            QdrantError: the client could not be created or the test request
                failed.
        """
        if self._is_connected:
            logger.warning("Qdrant 已经初始化,跳过重复初始化")
            return

        previous_config = self._config
        self._config = config

        try:
            # Do not leak a stale client left over from an earlier attempt.
            self._dispose_client()

            # Temporarily remove proxy variables from the process environment.
            saved_proxy_vars = {
                var: os.environ.pop(var)
                for var in self._PROXY_ENV_VARS
                if var in os.environ
            }
            try:
                self._client = self._build_client(config)

                # Test the connection with a cheap request.
                self._client.get_collections()

                self._is_connected = True
                self._last_check_time = time.time()

                logger.info(f"Qdrant 客户端初始化成功: {config.url or f'{config.host}:{config.port}'}")
            finally:
                # Restore the proxy variables.
                os.environ.update(saved_proxy_vars)

        except Exception as e:
            self._is_connected = False
            self._dispose_client()
            # Keep whatever config was in place before, so a first-time
            # failure leaves the manager uninitialized while a failed
            # reconnect can still be retried later.
            self._config = previous_config

            if isinstance(e, qdrant_exceptions.UnexpectedResponse):
                detail = f"初始化失败 - 响应错误: {str(e)}"
            elif isinstance(e, qdrant_exceptions.ResponseHandlingException):
                detail = f"初始化失败 - 响应处理错误: {str(e)}"
            else:
                detail = f"初始化失败: {str(e)}"

            logger.error(f"Qdrant {detail}")
            raise QdrantError(detail) from e

    def _check_connection(self) -> bool:
        """Check whether the connection is usable.

        The server is queried at most once per ``_check_interval`` seconds;
        in between, the last known state is trusted.

        Returns:
            bool: True when the connection is considered valid.
        """
        if not self._client or not self._is_connected:
            return False

        # Avoid checking too frequently.
        current_time = time.time()
        if current_time - self._last_check_time < self._check_interval:
            return True

        try:
            # Run a simple request to test the connection.
            self._client.get_collections()
            self._last_check_time = current_time
            return True

        except Exception as e:
            logger.warning(f"连接检查失败: {str(e)}")
            self._is_connected = False
            return False

    def _reconnect(self) -> bool:
        """Rebuild the client from the stored configuration.

        Returns:
            bool: True when the reconnect succeeded.
        """
        if not self._config:
            return False

        logger.info("尝试重新连接 Qdrant...")

        try:
            # Close the old connection.
            self._dispose_client()
            # initialize() is a no-op while the connected flag is set.
            self._is_connected = False

            self.initialize(self._config)
            return True

        except Exception as e:
            logger.error(f"重连失败: {str(e)}")
            return False

    def _get_client(self) -> QdrantClient:
        """Return the client, reconnecting first when the connection is bad.

        Returns:
            QdrantClient: the active client.

        Raises:
            QdrantError: the connection is down and reconnecting failed.
        """
        if not self._check_connection():
            if not self._reconnect():
                raise QdrantError("无法连接到 Qdrant,请检查配置和网络")

        return self._client

    def _execute_with_retry(self, func: Callable, *args, retry_count: Optional[int] = None, **kwargs) -> Any:
        """Run ``func(client, *args, **kwargs)`` with retries.

        Args:
            func: callable receiving the client as its first argument.
            *args: positional arguments forwarded to *func*.
            retry_count: number of retries; None uses the configured value
                (3 when no config is set).
            **kwargs: keyword arguments forwarded to *func*.

        Returns:
            Any: whatever *func* returns.

        Raises:
            QdrantError: every attempt failed; the last underlying exception
                is attached as the cause.
        """
        if retry_count is None:
            retry_count = self._config.retry_count if self._config else 3
        delay = self._config.retry_delay if self._config else 1

        last_error = None
        last_cause = None
        for attempt in range(retry_count + 1):
            attempts_left = attempt < retry_count
            try:
                client = self._get_client()
                return func(client, *args, **kwargs)

            except qdrant_exceptions.UnexpectedResponse as e:
                last_error = QdrantError(f"响应错误: {str(e)}")
                last_cause = e
                if attempts_left:
                    # Make the next _get_client() call rebuild the client.
                    self._is_connected = False

            except qdrant_exceptions.ResponseHandlingException as e:
                last_error = QdrantError(f"响应处理错误: {str(e)}")
                last_cause = e
                if attempts_left:
                    self._is_connected = False

            except Exception as e:
                last_error = QdrantError(f"执行操作时发生错误: {str(e)}")
                last_cause = e
                # Connection-looking errors request a reconnect.
                error_msg = str(e).lower()
                if "connection" in error_msg or "timeout" in error_msg or "network" in error_msg:
                    self._is_connected = False

            if attempts_left:
                time.sleep(delay)

        if last_error is None:
            # Only reachable when no attempt was made (negative retry_count).
            raise QdrantError("操作执行失败")
        raise last_error from last_cause

    def register_hook(self, event: str, hook: Callable):
        """Register a hook for *event*.

        Args:
            event: event name; must be one of the keys of the hook registry.
            hook: callable to invoke for the event.

        Raises:
            ValueError: *event* is not a known event name.
        """
        if event not in self._hooks:
            raise ValueError(f"未知的事件类型: {event}")

        self._hooks[event].append(hook)
        logger.debug(f"注册 hook: {event}")

    def unregister_hook(self, event: str, hook: Callable):
        """Remove a previously registered hook; unknown pairs are ignored.

        Args:
            event: event name.
            hook: the callable that was registered.
        """
        if event in self._hooks and hook in self._hooks[event]:
            self._hooks[event].remove(hook)
            logger.debug(f"取消注册 hook: {event}")

    def get_collections(self) -> List[str]:
        """Return the names of all collections.

        Returns:
            List[str]: collection names.

        Raises:
            QdrantError: not initialized, or the request failed.
        """
        self._ensure_initialized()

        def _get_collections(client):
            collections = client.get_collections()
            return [col.name for col in collections.collections]

        return self._execute_with_retry(_get_collections)

    def create_collection(
        self,
        collection_name: str,
        vectors_config: Union[VectorParams, Dict[str, Any]],
        **kwargs
    ) -> bool:
        """Create a collection.

        ``before_create_collection`` hooks receive
        ``(collection_name, vectors_config, **kwargs)``;
        ``after_create_collection`` hooks receive ``(collection_name, result)``.

        Args:
            collection_name: name of the collection.
            vectors_config: vector configuration.
            **kwargs: extra arguments forwarded to the client.

        Returns:
            bool: True when the collection was created.

        Raises:
            QdrantError: not initialized, or the request failed.
        """
        self._ensure_initialized()

        for hook in self._hooks.get('before_create_collection', []):
            hook(collection_name, vectors_config, **kwargs)

        def _create_collection(client, name, config, **kw):
            client.create_collection(
                collection_name=name,
                vectors_config=config,
                **kw
            )
            return True

        result = self._execute_with_retry(_create_collection, collection_name, vectors_config, **kwargs)

        for hook in self._hooks.get('after_create_collection', []):
            hook(collection_name, result)

        return result

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a collection.

        ``before_delete_collection`` hooks receive ``(collection_name)``;
        ``after_delete_collection`` hooks receive ``(collection_name, result)``.

        Args:
            collection_name: name of the collection.

        Returns:
            bool: True when the delete request succeeded.

        Raises:
            QdrantError: not initialized, or the request failed.
        """
        self._ensure_initialized()

        for hook in self._hooks.get('before_delete_collection', []):
            hook(collection_name)

        def _delete_collection(client, name):
            client.delete_collection(collection_name=name)
            return True

        result = self._execute_with_retry(_delete_collection, collection_name)

        for hook in self._hooks.get('after_delete_collection', []):
            hook(collection_name, result)

        return result

    def collection_exists(self, collection_name: str) -> bool:
        """Check whether a collection exists.

        Args:
            collection_name: name of the collection.

        Returns:
            bool: True when the collection exists.

        Raises:
            QdrantError: not initialized, or the lookup failed for a reason
                other than "not found".
        """
        self._ensure_initialized()

        def _collection_exists(client, name):
            try:
                client.get_collection(name)
                return True
            except qdrant_exceptions.UnexpectedResponse as e:
                # Only "not found" means the collection is absent; other
                # statuses (e.g. server errors) must not be reported as
                # "does not exist". A missing status_code keeps the old
                # behaviour of returning False.
                if getattr(e, 'status_code', 404) == 404:
                    return False
                raise

        return self._execute_with_retry(_collection_exists, collection_name)

    def upsert_points(
        self,
        collection_name: str,
        points: List[PointStruct],
        **kwargs
    ) -> UpdateStatus:
        """Insert or update points.

        Args:
            collection_name: name of the collection.
            points: points to upsert.
            **kwargs: extra arguments forwarded to the client.

        Returns:
            UpdateStatus: whatever ``client.upsert`` returns.

        Raises:
            QdrantError: not initialized, or the request failed.
        """
        self._ensure_initialized()

        def _upsert_points(client, name, pts, **kw):
            return client.upsert(
                collection_name=name,
                points=pts,
                **kw
            )

        return self._execute_with_retry(_upsert_points, collection_name, points, **kwargs)

    def search_points(
        self,
        collection_name: str,
        query_vector: Union[List[float], str],
        limit: int = 10,
        score_threshold: Optional[float] = None,
        filter: Optional[Filter] = None,
        **kwargs
    ) -> List[Any]:
        """Search for points.

        ``before_query`` hooks receive ``(collection_name, query_vector)`` and
        may return a replacement query vector (a falsy return keeps the
        current one). ``after_query`` hooks receive
        ``(collection_name, query_vector, result)``.

        Args:
            collection_name: name of the collection.
            query_vector: query vector or named vector.
            limit: maximum number of results.
            score_threshold: minimum score.
            filter: filter conditions.
            **kwargs: extra arguments forwarded to the client.

        Returns:
            List[Any]: search results.

        Raises:
            QdrantError: not initialized, or the request failed.
        """
        self._ensure_initialized()

        for hook in self._hooks.get('before_query', []):
            query_vector = hook(collection_name, query_vector) or query_vector

        def _search_points(client, name, qv, lim, st, flt, **kw):
            return client.search(
                collection_name=name,
                query_vector=qv,
                limit=lim,
                score_threshold=st,
                query_filter=flt,
                **kw
            )

        result = self._execute_with_retry(
            _search_points,
            collection_name,
            query_vector,
            limit,
            score_threshold,
            filter,
            **kwargs
        )

        for hook in self._hooks.get('after_query', []):
            hook(collection_name, query_vector, result)

        return result

    def delete_points(
        self,
        collection_name: str,
        points_selector: Union[List[int], Filter],
        **kwargs
    ) -> UpdateStatus:
        """Delete points.

        Args:
            collection_name: name of the collection.
            points_selector: point selector (list of IDs or a filter).
            **kwargs: extra arguments forwarded to the client.

        Returns:
            UpdateStatus: whatever ``client.delete`` returns.

        Raises:
            QdrantError: not initialized, or the request failed.
        """
        self._ensure_initialized()

        def _delete_points(client, name, selector, **kw):
            return client.delete(
                collection_name=name,
                points_selector=selector,
                **kw
            )

        return self._execute_with_retry(_delete_points, collection_name, points_selector, **kwargs)

    def get_point(self, collection_name: str, point_id: Union[int, str], **kwargs) -> Optional[Any]:
        """Fetch a single point by ID.

        Args:
            collection_name: name of the collection.
            point_id: ID of the point.
            **kwargs: extra arguments forwarded to the client.

        Returns:
            Optional[Any]: the point, or None when nothing was retrieved.

        Raises:
            QdrantError: not initialized, or the request failed.
        """
        self._ensure_initialized()

        def _get_point(client, name, pid, **kw):
            result = client.retrieve(
                collection_name=name,
                ids=[pid],
                **kw
            )
            return result[0] if result else None

        return self._execute_with_retry(_get_point, collection_name, point_id, **kwargs)

    def close(self):
        """Close the client and return the manager to the uninitialized state.

        After this call the public operations raise ``QdrantError`` until
        ``initialize()`` is called again.
        """
        if self._client is not None:
            self._dispose_client()
            self._is_connected = False
            logger.info("Qdrant 客户端已关闭")
        # Forget the config so later calls do not silently reconnect.
        self._config = None

    def is_connected(self) -> bool:
        """Report whether the manager is connected.

        Returns:
            bool: True when the connected flag is set and the connection
                check passes.
        """
        return self._is_connected and self._check_connection()
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
# Create the global singleton instance
|
|
585
|
-
qdrant_manager = QdrantManager()
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
class ContextAwareQdrantManager:
    """Wrapper around ``QdrantManager`` that scopes operations by tenant/company.

    Search and filter-based delete get tenant/company conditions added, and
    upserted points get tenant/company/user fields written into their
    payload. IDs passed to the constructor override the ones taken from the
    auth context.

    NOTE(review): ``get_point`` and ID-based ``delete_points`` are forwarded
    without any tenant/company scoping, and when no tenant/company can be
    resolved the operations run unscoped.
    """

    def __init__(
        self,
        manager: QdrantManager,
        auth_tenant_id: Optional[str] = None,
        auth_company_id: Optional[str] = None,
        user_id: Optional[str] = None,
    ):
        """Initialize the wrapper.

        Args:
            manager: the ``QdrantManager`` to delegate to.
            auth_tenant_id: tenant ID; when None it is read from the context.
            auth_company_id: company ID; when None it is read from the context.
            user_id: user ID; when None it is read from the context (only
                written into upserted payloads).
        """
        self._manager = manager
        self._auth_tenant_id = auth_tenant_id
        self._auth_company_id = auth_company_id
        self._user_id = user_id

    def _get_auth_context(self) -> Dict[str, Optional[str]]:
        """Read auth information from the current context.

        Any failure is logged and treated as "no auth data".

        Returns:
            Dict[str, Optional[str]]: dict with ``auth_tenant_id``,
                ``auth_company_id`` and ``user_id`` (each possibly None).
        """
        try:
            from ...auth import get_auth_data
            auth_data = get_auth_data()
            if auth_data:
                return {
                    'auth_tenant_id': auth_data.auth_tenant_id,
                    'auth_company_id': auth_data.current_company,
                    'user_id': auth_data.user_id,
                }
        except Exception as e:
            # NOTE(review): failing open here means callers proceed without
            # tenant/company scoping when the auth lookup breaks.
            logger.warning(f"Failed to get auth context: {e}")

        return {
            'auth_tenant_id': None,
            'auth_company_id': None,
            'user_id': None,
        }

    def _resolve_scope(self, *keys: str) -> Dict[str, Optional[str]]:
        """Resolve the requested IDs with at most one auth-context lookup.

        Explicit constructor values win; the context is consulted only when
        at least one requested key has no explicit value.

        Args:
            *keys: any of ``auth_tenant_id``, ``auth_company_id``, ``user_id``.

        Returns:
            Dict[str, Optional[str]]: the requested keys, in the order given.
        """
        explicit = {
            'auth_tenant_id': self._auth_tenant_id,
            'auth_company_id': self._auth_company_id,
            'user_id': self._user_id,
        }
        scope = {key: explicit[key] for key in keys}
        if not all(scope.values()):
            context = self._get_auth_context()
            for key in keys:
                if not scope[key]:
                    scope[key] = context.get(key)
        return scope

    def _get_tenant_id(self) -> Optional[str]:
        """Return the tenant ID: the explicit one if set, else from the context."""
        return self._resolve_scope('auth_tenant_id')['auth_tenant_id']

    def _get_company_id(self) -> Optional[str]:
        """Return the company ID: the explicit one if set, else from the context."""
        return self._resolve_scope('auth_company_id')['auth_company_id']

    def _get_user_id(self) -> Optional[str]:
        """Return the user ID: the explicit one if set, else from the context."""
        return self._resolve_scope('user_id')['user_id']

    def _build_context_filter(self, existing_filter: Optional[Filter] = None) -> Optional[Filter]:
        """Build a filter that adds tenant/company conditions.

        Args:
            existing_filter: the caller's filter, if any.

        Returns:
            Optional[Filter]: *existing_filter* unchanged when neither tenant
                nor company is known; otherwise a filter whose ``must`` list
                holds the tenant/company conditions, preceded by
                *existing_filter* when one was given.
        """
        scope = self._resolve_scope('auth_tenant_id', 'auth_company_id')
        conditions = [
            FieldCondition(key=key, match=MatchValue(value=value))
            for key, value in scope.items()
            if value
        ]

        if not conditions:
            # No tenant and no company: keep the original filter.
            return existing_filter

        if existing_filter is None:
            return Filter(must=conditions)

        # Nest the caller's filter as a single ``must`` clause instead of
        # rebuilding it field by field: this keeps every clause it carries
        # and does not depend on its ``must`` being a list.
        return Filter(must=[existing_filter, *conditions])

    @staticmethod
    def _stamp_payload(payload: Dict[str, Any], scope: Dict[str, Optional[str]]) -> Dict[str, Any]:
        """Write the non-empty scope IDs into *payload* (in place) and return it."""
        if scope.get('auth_tenant_id'):
            payload['auth_tenant_id'] = scope['auth_tenant_id']
        if scope.get('auth_company_id'):
            payload['auth_company_id'] = scope['auth_company_id']
        if scope.get('user_id'):
            # The user ID is stored under a different payload key.
            payload['auth_user_id'] = scope['user_id']
        return payload

    def _ensure_context_in_payload(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Make sure *payload* carries tenant, company and user information.

        Args:
            payload: the point payload; a non-dict value is replaced by a
                new empty dict.

        Returns:
            Dict[str, Any]: the payload (modified in place when it is a dict)
                with the known IDs written into it.
        """
        if not isinstance(payload, dict):
            payload = {}

        scope = self._resolve_scope('auth_tenant_id', 'auth_company_id', 'user_id')
        return self._stamp_payload(payload, scope)

    # The methods below proxy QdrantManager, adding context scoping where noted.

    def get_collections(self) -> List[str]:
        """Return the names of all collections."""
        return self._manager.get_collections()

    def create_collection(
        self,
        collection_name: str,
        vectors_config: Union[VectorParams, Dict[str, Any]],
        **kwargs
    ) -> bool:
        """Create a collection."""
        return self._manager.create_collection(collection_name, vectors_config, **kwargs)

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a collection."""
        return self._manager.delete_collection(collection_name)

    def collection_exists(self, collection_name: str) -> bool:
        """Check whether a collection exists."""
        return self._manager.collection_exists(collection_name)

    def upsert_points(
        self,
        collection_name: str,
        points: List[PointStruct],
        **kwargs
    ) -> UpdateStatus:
        """Insert or update points, adding tenant/company/user to each payload.

        The input points are not modified: new ``PointStruct`` objects are
        built from each point's id, vector and a copy of its payload.
        """
        # Resolve the IDs once per call rather than once per point.
        scope = self._resolve_scope('auth_tenant_id', 'auth_company_id', 'user_id')

        updated_points = []
        for point in points:
            source_payload = point.payload if point.payload else {}
            payload = source_payload.copy() if isinstance(source_payload, dict) else {}
            updated_points.append(
                PointStruct(
                    id=point.id,
                    vector=point.vector,
                    payload=self._stamp_payload(payload, scope),
                )
            )

        return self._manager.upsert_points(collection_name, updated_points, **kwargs)

    def search_points(
        self,
        collection_name: str,
        query_vector: Union[List[float], str],
        limit: int = 10,
        score_threshold: Optional[float] = None,
        filter: Optional[Filter] = None,
        **kwargs
    ) -> List[Any]:
        """Search for points, adding tenant/company filter conditions."""
        context_filter = self._build_context_filter(filter)
        return self._manager.search_points(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=limit,
            score_threshold=score_threshold,
            filter=context_filter,
            **kwargs
        )

    def delete_points(
        self,
        collection_name: str,
        points_selector: Union[List[int], Filter],
        **kwargs
    ) -> UpdateStatus:
        """Delete points, adding tenant/company conditions to filter selectors.

        Only ``Filter`` selectors are scoped. Any other selector (such as a
        list of IDs) is forwarded as-is, so the caller is responsible for
        making sure those points belong to the current tenant/company.
        """
        if isinstance(points_selector, Filter):
            context_filter = self._build_context_filter(points_selector)
            return self._manager.delete_points(collection_name, context_filter, **kwargs)

        # No filter to extend for an ID list: delete directly, unscoped.
        return self._manager.delete_points(collection_name, points_selector, **kwargs)

    def get_point(self, collection_name: str, point_id: Union[int, str], **kwargs) -> Optional[Any]:
        """Fetch a single point by ID (no tenant/company scoping is applied)."""
        return self._manager.get_point(collection_name, point_id, **kwargs)

    def close(self):
        """Close the underlying manager's client connection."""
        self._manager.close()

    def is_connected(self) -> bool:
        """Report whether the underlying manager is connected."""
        return self._manager.is_connected()

    def register_hook(self, event: str, hook: Callable):
        """Register a hook on the underlying manager."""
        self._manager.register_hook(event, hook)

    def unregister_hook(self, event: str, hook: Callable):
        """Unregister a hook from the underlying manager."""
        self._manager.unregister_hook(event, hook)
|
|
853
|
-
|