fiuai-sdk-python 0.6.9__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,853 +0,0 @@
1
- # -- coding: utf-8 --
2
- # Project: fiuai-world
3
- # Created Date: 2025-01-27
4
- # Author: liming
5
- # Email: lmlala@aliyun.com
6
- # Copyright (c) 2025 FiuAI
7
-
8
- import os
9
- import time
10
- import threading
11
- from typing import Optional, Dict, Any, List, Callable, Union
12
- from dataclasses import dataclass
13
-
14
- from qdrant_client import QdrantClient
15
- from qdrant_client.models import (
16
- Distance,
17
- VectorParams,
18
- PointStruct,
19
- Filter,
20
- FieldCondition,
21
- MatchValue,
22
- CollectionStatus,
23
- UpdateStatus,
24
- )
25
- from qdrant_client.http import exceptions as qdrant_exceptions
26
-
27
- from utils import get_logger
28
- from utils.errors import FiuaiBaseError
29
-
30
logger = get_logger(__name__)


class QdrantError(FiuaiBaseError):
    """Error raised by this module when a Qdrant connection or operation fails."""
36
-
37
-
38
@dataclass
class QdrantConfig:
    """Connection settings for a Qdrant server.

    When ``url`` is set it is used instead of ``host``/``port``/``https``.
    """
    host: str
    port: int = 6333
    api_key: Optional[str] = None
    timeout: int = 30  # passed straight to QdrantClient(timeout=...)
    retry_count: int = 3  # number of retries after the first attempt
    retry_delay: int = 1  # pause between attempts, in seconds
    prefer_grpc: bool = False
    https: bool = False
    url: Optional[str] = None  # when provided, takes precedence over host/port

    @classmethod
    def from_env(cls, host: Optional[str] = None, port: Optional[int] = None) -> 'QdrantConfig':
        """Build a configuration from ``QDRANT_*`` environment variables.

        Args:
            host: Explicit host. When falsy, ``QDRANT_HOST`` is used
                (default ``127.0.0.1``).
            port: Explicit port. When falsy, ``QDRANT_PORT`` is used
                (default ``6333``).

        Returns:
            QdrantConfig: The resulting configuration. Boolean flags
            (``QDRANT_PREFER_GRPC``, ``QDRANT_HTTPS``) are enabled only by the
            case-insensitive value ``true``.
        """
        read = os.getenv

        def flag(name: str) -> bool:
            return read(name, 'false').lower() == 'true'

        resolved_host = host or read('QDRANT_HOST', '127.0.0.1')
        resolved_port = port or int(read('QDRANT_PORT', '6333'))
        return cls(
            host=resolved_host,
            port=resolved_port,
            api_key=read('QDRANT_API_KEY'),
            timeout=int(read('QDRANT_TIMEOUT', '30')),
            retry_count=int(read('QDRANT_RETRY_COUNT', '3')),
            retry_delay=int(read('QDRANT_RETRY_DELAY', '1')),
            prefer_grpc=flag('QDRANT_PREFER_GRPC'),
            https=flag('QDRANT_HTTPS'),
            url=read('QDRANT_URL'),
        )
73
-
74
-
75
class QdrantManager:
    """Singleton owner of the process-wide Qdrant client.

    Provides connection management (periodic health checks, reconnection and
    retries), basic CRUD helpers, and a hook mechanism around collection and
    query operations.
    """
    _instance: Optional['QdrantManager'] = None
    _lock = threading.Lock()
    _initialized = False

    # 4xx statuses that still describe a transient condition (request timeout,
    # rate limiting) and are therefore worth retrying.
    _RETRYABLE_CLIENT_STATUSES = frozenset({408, 429})

    def __new__(cls):
        # Double-checked locking: the lock is only taken while no instance exists.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every QdrantManager() call; set up state only once.
        if hasattr(self, '_initialized') and self._initialized:
            return

        self._config: Optional[QdrantConfig] = None  # stored once initialize() succeeds
        self._client: Optional[QdrantClient] = None
        self._is_connected = False
        self._last_check_time = 0
        self._check_interval = 30  # minimum seconds between live health checks
        self._hooks: Dict[str, List[Callable]] = {
            'before_query': [],
            'after_query': [],
            'before_create_collection': [],
            'after_create_collection': [],
            'before_delete_collection': [],
            'after_delete_collection': [],
        }
        self._initialized = True

    def initialize(self, config: QdrantConfig):
        """Create the Qdrant client and verify the connection.

        A repeated call while connected is a no-op (a warning is logged). The
        config is remembered only after a successful connection, so a failed
        first initialization leaves the manager uninitialized.

        Args:
            config: Qdrant configuration.

        Raises:
            QdrantError: If the client cannot be created or the test request fails.
        """
        if self._is_connected:
            logger.warning("Qdrant 已经初始化,跳过重复初始化")
            return

        try:
            # Temporarily remove proxy environment variables to force a direct
            # connection; they are restored in the finally block below.
            # NOTE(review): os.environ is process-global, so other threads can
            # observe these variables missing while this runs.
            proxy_env_vars = [
                'HTTP_PROXY', 'HTTPS_PROXY', 'http_proxy', 'https_proxy',
                'ALL_PROXY', 'all_proxy', 'SOCKS_PROXY', 'socks_proxy',
                'NO_PROXY', 'no_proxy'
            ]
            saved_proxy_vars = {}
            for var in proxy_env_vars:
                if var in os.environ:
                    saved_proxy_vars[var] = os.environ.pop(var)

            try:
                if config.url:
                    # A URL takes precedence over host/port.
                    self._client = QdrantClient(
                        url=config.url,
                        api_key=config.api_key,
                        timeout=config.timeout,
                        prefer_grpc=config.prefer_grpc,
                    )
                else:
                    self._client = QdrantClient(
                        host=config.host,
                        port=config.port,
                        api_key=config.api_key,
                        timeout=config.timeout,
                        prefer_grpc=config.prefer_grpc,
                        https=config.https,
                    )

                # Cheap round-trip that proves the server is reachable.
                self._client.get_collections()

                self._config = config
                self._is_connected = True
                self._last_check_time = time.time()

                logger.info(f"Qdrant 客户端初始化成功: {config.url or f'{config.host}:{config.port}'}")
            finally:
                for var, value in saved_proxy_vars.items():
                    os.environ[var] = value

        except qdrant_exceptions.UnexpectedResponse as e:
            self._discard_client()
            logger.error(f"Qdrant 初始化失败 - 响应错误: {str(e)}")
            raise QdrantError(f"初始化失败 - 响应错误: {str(e)}") from e
        except qdrant_exceptions.ResponseHandlingException as e:
            self._discard_client()
            logger.error(f"Qdrant 初始化失败 - 响应处理错误: {str(e)}")
            raise QdrantError(f"初始化失败 - 响应处理错误: {str(e)}") from e
        except Exception as e:
            self._discard_client()
            logger.error(f"Qdrant 初始化失败: {str(e)}")
            raise QdrantError(f"初始化失败: {str(e)}") from e

    def _discard_client(self) -> None:
        """Drop the current client and mark the manager disconnected.

        The client's own ``close()`` is called when it has one. This is best
        effort: errors are logged at debug level and never propagate.
        """
        client, self._client = self._client, None
        self._is_connected = False
        if client is None:
            return
        close = getattr(client, 'close', None)
        if callable(close):
            try:
                close()
            except Exception as e:
                logger.debug(f"Error while closing Qdrant client: {e}")

    def _check_connection(self) -> bool:
        """Return whether the client is believed to be usable.

        A live ``get_collections()`` probe runs at most once per
        ``_check_interval`` seconds; in between, the cached flag is trusted.

        Returns:
            bool: True if the connection is considered valid.
        """
        if not self._client or not self._is_connected:
            return False

        current_time = time.time()
        if current_time - self._last_check_time < self._check_interval:
            return True

        try:
            self._client.get_collections()
            self._last_check_time = current_time
            return True

        except Exception as e:
            logger.warning(f"连接检查失败: {str(e)}")
            self._is_connected = False
            return False

    def _reconnect(self) -> bool:
        """Discard the current client and initialize again with the stored config.

        Returns:
            bool: True if reconnection succeeded; False if there is no stored
            config or initialization failed (the failure is logged).
        """
        if self._config is None:
            return False

        logger.info("尝试重新连接 Qdrant...")
        config = self._config

        try:
            self._discard_client()
            self.initialize(config)
            return True

        except Exception as e:
            logger.error(f"重连失败: {str(e)}")
            return False

    def _get_client(self) -> QdrantClient:
        """Return a usable client, reconnecting first if the connection is stale.

        Returns:
            QdrantClient: The live client.

        Raises:
            QdrantError: If the connection is invalid and reconnection fails.
        """
        if not self._check_connection():
            if not self._reconnect():
                raise QdrantError("无法连接到 Qdrant,请检查配置和网络")

        return self._client

    def _ensure_initialized(self) -> None:
        """Raise unless initialize() has succeeded (and close() has not been called since).

        This deliberately checks the stored config rather than the live
        connection flag. A transient failure clears ``_is_connected``, and the
        next call must still reach ``_get_client`` so that it can reconnect.
        """
        if self._config is None:
            raise QdrantError("Qdrant 未初始化,请先调用 initialize()")

    @classmethod
    def _is_retryable_response(cls, error: Exception) -> bool:
        """Return False for 4xx responses that would fail the same way on retry."""
        status = getattr(error, 'status_code', None)
        if not isinstance(status, int):
            return True
        if 400 <= status < 500:
            return status in cls._RETRYABLE_CLIENT_STATUSES
        return True

    def _execute_with_retry(self, func: Callable, *args, retry_count: Optional[int] = None, **kwargs) -> Any:
        """Run ``func(client, *args, **kwargs)`` with reconnection and retries.

        Args:
            func: Callable that receives the live client as its first argument.
            *args: Extra positional arguments for ``func``.
            retry_count: Retries after the first attempt. None uses the
                configured value (3 when no config is stored).
            **kwargs: Extra keyword arguments for ``func``.

        Returns:
            Any: The value returned by ``func``.

        Raises:
            QdrantError: When every attempt fails, or immediately when the server
                rejects the request with a non-transient 4xx status. The original
                exception is attached as ``__cause__``.
        """
        if retry_count is None:
            retry_count = self._config.retry_count if self._config else 3
        retry_delay = self._config.retry_delay if self._config else 1

        last_error: Optional[QdrantError] = None
        last_cause: Optional[BaseException] = None
        for attempt in range(retry_count + 1):
            can_retry = attempt < retry_count
            try:
                client = self._get_client()
                return func(client, *args, **kwargs)

            except qdrant_exceptions.UnexpectedResponse as e:
                last_error, last_cause = QdrantError(f"响应错误: {str(e)}"), e
                if not self._is_retryable_response(e):
                    # The server understood and rejected the request (bad
                    # parameters, collection already exists, ...). Neither a retry
                    # nor a reconnect can change that answer.
                    break
                if can_retry:
                    time.sleep(retry_delay)
                    self._is_connected = False  # force a reconnect on the next attempt

            except qdrant_exceptions.ResponseHandlingException as e:
                last_error, last_cause = QdrantError(f"响应处理错误: {str(e)}"), e
                if can_retry:
                    time.sleep(retry_delay)
                    self._is_connected = False  # force a reconnect on the next attempt

            except Exception as e:
                last_error, last_cause = QdrantError(f"执行操作时发生错误: {str(e)}"), e
                error_msg = str(e).lower()
                if "connection" in error_msg or "timeout" in error_msg or "network" in error_msg:
                    # Looks like a transport failure. Invalidate the connection even
                    # on the last attempt so the next call reconnects first.
                    self._is_connected = False
                if can_retry:
                    time.sleep(retry_delay)

        if last_error is None:
            raise QdrantError("操作执行失败")
        raise last_error from last_cause

    def register_hook(self, event: str, hook: Callable):
        """Register a hook for one of the predefined events.

        Args:
            event: Event name (a key of ``self._hooks``).
            hook: Callable invoked when the event fires.

        Raises:
            ValueError: If ``event`` is not a known event name.
        """
        if event not in self._hooks:
            raise ValueError(f"未知的事件类型: {event}")

        self._hooks[event].append(hook)
        logger.debug(f"注册 hook: {event}")

    def unregister_hook(self, event: str, hook: Callable):
        """Remove a previously registered hook. Unknown events or hooks are ignored.

        Args:
            event: Event name.
            hook: The hook callable to remove.
        """
        if event in self._hooks and hook in self._hooks[event]:
            self._hooks[event].remove(hook)
            logger.debug(f"取消注册 hook: {event}")

    def get_collections(self) -> List[str]:
        """Return the names of all collections.

        Returns:
            List[str]: Collection names.

        Raises:
            QdrantError: If not initialized or the request ultimately fails.
        """
        self._ensure_initialized()

        def _get_collections(client):
            collections = client.get_collections()
            return [col.name for col in collections.collections]

        return self._execute_with_retry(_get_collections)

    def create_collection(
        self,
        collection_name: str,
        vectors_config: Union[VectorParams, Dict[str, Any]],
        **kwargs
    ) -> bool:
        """Create a collection, running the create-collection hooks around it.

        Args:
            collection_name: Collection name.
            vectors_config: Vector configuration.
            **kwargs: Extra arguments forwarded to ``QdrantClient.create_collection``.

        Returns:
            bool: True once the client call returns.

        Raises:
            QdrantError: If not initialized or the request ultimately fails.
        """
        self._ensure_initialized()

        for hook in self._hooks.get('before_create_collection', []):
            hook(collection_name, vectors_config, **kwargs)

        def _create_collection(client, name, config, **kw):
            client.create_collection(
                collection_name=name,
                vectors_config=config,
                **kw
            )
            return True

        result = self._execute_with_retry(_create_collection, collection_name, vectors_config, **kwargs)

        for hook in self._hooks.get('after_create_collection', []):
            hook(collection_name, result)

        return result

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a collection, running the delete-collection hooks around it.

        Args:
            collection_name: Collection name.

        Returns:
            bool: True once the client call returns.

        Raises:
            QdrantError: If not initialized or the request ultimately fails.
        """
        self._ensure_initialized()

        for hook in self._hooks.get('before_delete_collection', []):
            hook(collection_name)

        def _delete_collection(client, name):
            client.delete_collection(collection_name=name)
            return True

        result = self._execute_with_retry(_delete_collection, collection_name)

        for hook in self._hooks.get('after_delete_collection', []):
            hook(collection_name, result)

        return result

    def collection_exists(self, collection_name: str) -> bool:
        """Return whether a collection exists.

        Args:
            collection_name: Collection name.

        Returns:
            bool: True if ``get_collection`` succeeds. Any ``UnexpectedResponse``
            (not only "not found") is treated as "does not exist".
        """
        self._ensure_initialized()

        def _collection_exists(client, name):
            try:
                client.get_collection(name)
                return True
            except qdrant_exceptions.UnexpectedResponse:
                return False

        return self._execute_with_retry(_collection_exists, collection_name)

    def upsert_points(
        self,
        collection_name: str,
        points: List[PointStruct],
        **kwargs
    ) -> UpdateStatus:
        """Insert or update points.

        Args:
            collection_name: Collection name.
            points: Points to write.
            **kwargs: Extra arguments forwarded to ``QdrantClient.upsert``.

        Returns:
            UpdateStatus: Whatever ``QdrantClient.upsert`` returns.
        """
        self._ensure_initialized()

        def _upsert_points(client, name, pts, **kw):
            return client.upsert(
                collection_name=name,
                points=pts,
                **kw
            )

        return self._execute_with_retry(_upsert_points, collection_name, points, **kwargs)

    def search_points(
        self,
        collection_name: str,
        query_vector: Union[List[float], str],
        limit: int = 10,
        score_threshold: Optional[float] = None,
        filter: Optional[Filter] = None,
        **kwargs
    ) -> List[Any]:
        """Search a collection by vector.

        ``before_query`` hooks may replace the query vector by returning a
        truthy value. ``after_query`` hooks receive the results.

        Args:
            collection_name: Collection name.
            query_vector: Query vector or named vector.
            limit: Maximum number of results.
            score_threshold: Minimum score, or None for no threshold.
            filter: Optional payload filter.
            **kwargs: Extra arguments forwarded to ``QdrantClient.search``.

        Returns:
            List[Any]: Search results.
        """
        self._ensure_initialized()

        for hook in self._hooks.get('before_query', []):
            query_vector = hook(collection_name, query_vector) or query_vector

        def _search_points(client, name, qv, lim, st, flt, **kw):
            return client.search(
                collection_name=name,
                query_vector=qv,
                limit=lim,
                score_threshold=st,
                query_filter=flt,
                **kw
            )

        result = self._execute_with_retry(
            _search_points,
            collection_name,
            query_vector,
            limit,
            score_threshold,
            filter,
            **kwargs
        )

        for hook in self._hooks.get('after_query', []):
            hook(collection_name, query_vector, result)

        return result

    def delete_points(
        self,
        collection_name: str,
        points_selector: Union[List[int], Filter],
        **kwargs
    ) -> UpdateStatus:
        """Delete points selected by an ID list or a filter.

        Args:
            collection_name: Collection name.
            points_selector: Point IDs or a filter.
            **kwargs: Extra arguments forwarded to ``QdrantClient.delete``.

        Returns:
            UpdateStatus: Whatever ``QdrantClient.delete`` returns.
        """
        self._ensure_initialized()

        def _delete_points(client, name, selector, **kw):
            return client.delete(
                collection_name=name,
                points_selector=selector,
                **kw
            )

        return self._execute_with_retry(_delete_points, collection_name, points_selector, **kwargs)

    def get_point(self, collection_name: str, point_id: Union[int, str], **kwargs) -> Optional[Any]:
        """Fetch a single point by ID.

        Args:
            collection_name: Collection name.
            point_id: Point ID.
            **kwargs: Extra arguments forwarded to ``QdrantClient.retrieve``.

        Returns:
            Optional[Any]: The point, or None if it does not exist.
        """
        self._ensure_initialized()

        def _get_point(client, name, pid, **kw):
            result = client.retrieve(
                collection_name=name,
                ids=[pid],
                **kw
            )
            return result[0] if result else None

        return self._execute_with_retry(_get_point, collection_name, point_id, **kwargs)

    def close(self):
        """Release the client and mark the manager uninitialized.

        The client's own ``close()`` is called when available (best effort).
        Call initialize() again before further use.
        """
        had_client = self._client is not None
        self._discard_client()
        self._config = None
        if had_client:
            logger.info("Qdrant 客户端已关闭")

    def is_connected(self) -> bool:
        """Return whether the manager is connected (may run a live health check).

        Returns:
            bool: True if connected and the connection check passes.
        """
        return self._is_connected and self._check_connection()
582
-
583
-
584
# Shared module-level instance. Because QdrantManager.__new__ implements a
# singleton, later QdrantManager() calls return this same object.
qdrant_manager: QdrantManager = QdrantManager()
586
-
587
-
588
class ContextAwareQdrantManager:
    """QdrantManager wrapper that scopes operations to a tenant and company.

    Tenant, company and user IDs come from the constructor when given, otherwise
    from the request auth context. Searches and deletes get tenant/company
    conditions added. Upserts get the IDs written into every point's payload.
    Explicitly supplied IDs override the context values.

    NOTE(review): when neither a tenant nor a company ID can be resolved
    (including when reading the context fails), operations run unscoped.
    """

    def __init__(
        self,
        manager: QdrantManager,
        auth_tenant_id: Optional[str] = None,
        auth_company_id: Optional[str] = None,
        user_id: Optional[str] = None,
    ):
        """Wrap a QdrantManager.

        Args:
            manager: The QdrantManager to delegate to.
            auth_tenant_id: Tenant ID. None means it is read from the context.
            auth_company_id: Company ID. None means it is read from the context.
            user_id: User ID, written into upserted payloads. None means it is
                read from the context.
        """
        self._manager = manager
        self._auth_tenant_id = auth_tenant_id
        self._auth_company_id = auth_company_id
        self._user_id = user_id

    def _get_auth_context(self) -> Dict[str, Optional[str]]:
        """Read tenant, company and user IDs from ``pkg.context.get_auth_data()``.

        The company ID comes from ``auth_data.current_company``. Any failure is
        logged as a warning, and all three IDs are then returned as None.

        Returns:
            Dict[str, Optional[str]]: Keys ``auth_tenant_id``, ``auth_company_id``
            and ``user_id``.
        """
        try:
            from pkg.context import get_auth_data
            auth_data = get_auth_data()
            if auth_data:
                return {
                    'auth_tenant_id': auth_data.auth_tenant_id,
                    'auth_company_id': auth_data.current_company,
                    'user_id': auth_data.user_id,
                }
        except Exception as e:
            logger.warning(f"Failed to get auth context: {e}")

        return {
            'auth_tenant_id': None,
            'auth_company_id': None,
            'user_id': None,
        }

    def _resolve_ids(self) -> Dict[str, Optional[str]]:
        """Return tenant/company/user IDs, reading the auth context at most once.

        Explicit constructor values win. The context is consulted only when at
        least one of them is missing (falsy).
        """
        explicit = {
            'auth_tenant_id': self._auth_tenant_id,
            'auth_company_id': self._auth_company_id,
            'user_id': self._user_id,
        }
        if all(explicit.values()):
            return explicit
        context = self._get_auth_context()
        return {key: value or context.get(key) for key, value in explicit.items()}

    def _get_tenant_id(self) -> Optional[str]:
        """Return the tenant ID: the explicit value if set, otherwise the context's."""
        if self._auth_tenant_id:
            return self._auth_tenant_id
        context = self._get_auth_context()
        return context.get('auth_tenant_id')

    def _get_company_id(self) -> Optional[str]:
        """Return the company ID: the explicit value if set, otherwise the context's."""
        if self._auth_company_id:
            return self._auth_company_id
        context = self._get_auth_context()
        return context.get('auth_company_id')

    def _get_user_id(self) -> Optional[str]:
        """Return the user ID: the explicit value if set, otherwise the context's."""
        if self._user_id:
            return self._user_id
        context = self._get_auth_context()
        return context.get('user_id')

    @staticmethod
    def _context_conditions(ids: Dict[str, Optional[str]]) -> List[FieldCondition]:
        """Return payload match conditions for whichever tenant/company IDs are set."""
        return [
            FieldCondition(key=key, match=MatchValue(value=ids[key]))
            for key in ('auth_tenant_id', 'auth_company_id')
            if ids.get(key)
        ]

    def _build_context_filter(self, existing_filter: Optional[Filter] = None) -> Optional[Filter]:
        """Combine ``existing_filter`` with tenant/company conditions.

        Args:
            existing_filter: The caller's filter, if any.

        Returns:
            Optional[Filter]: ``existing_filter`` unchanged when there is no
            tenant/company context. Otherwise, a filter that requires both the
            caller's filter and the context conditions.
        """
        conditions = self._context_conditions(self._resolve_ids())
        if not conditions:
            return existing_filter

        if existing_filter is None:
            return Filter(must=conditions)

        # Nest the caller's filter as one more mandatory clause instead of
        # concatenating its ``must`` list. This keeps every field of the original
        # filter (should / must_not / min_should, or a non-list ``must``) intact.
        return Filter(must=[existing_filter, *conditions])

    def _ensure_context_in_payload(
        self,
        payload: Dict[str, Any],
        resolved_ids: Optional[Dict[str, Optional[str]]] = None,
    ) -> Dict[str, Any]:
        """Write tenant, company and user IDs into ``payload`` (in place).

        Args:
            payload: Point payload. A non-dict value is replaced by a new dict.
            resolved_ids: IDs already obtained from ``_resolve_ids()``. When
                None, they are resolved here.

        Returns:
            Dict[str, Any]: The payload. It carries ``auth_tenant_id``,
            ``auth_company_id`` and ``auth_user_id`` for whichever IDs are set.
        """
        if not isinstance(payload, dict):
            payload = {}

        ids = resolved_ids if resolved_ids is not None else self._resolve_ids()

        if ids.get('auth_tenant_id'):
            payload['auth_tenant_id'] = ids['auth_tenant_id']
        if ids.get('auth_company_id'):
            payload['auth_company_id'] = ids['auth_company_id']
        if ids.get('user_id'):
            payload['auth_user_id'] = ids['user_id']

        return payload

    # Delegates to QdrantManager, adding context scoping where applicable.

    def get_collections(self) -> List[str]:
        """Return all collection names (not scoped)."""
        return self._manager.get_collections()

    def create_collection(
        self,
        collection_name: str,
        vectors_config: Union[VectorParams, Dict[str, Any]],
        **kwargs
    ) -> bool:
        """Create a collection (not scoped)."""
        return self._manager.create_collection(collection_name, vectors_config, **kwargs)

    def delete_collection(self, collection_name: str) -> bool:
        """Delete a collection (not scoped)."""
        return self._manager.delete_collection(collection_name)

    def collection_exists(self, collection_name: str) -> bool:
        """Return whether a collection exists (not scoped)."""
        return self._manager.collection_exists(collection_name)

    def upsert_points(
        self,
        collection_name: str,
        points: List[PointStruct],
        **kwargs
    ) -> UpdateStatus:
        """Insert or update points, stamping context IDs into each payload.

        Input points are not mutated: each one is copied into a new PointStruct
        with a copied payload.
        """
        ids = self._resolve_ids()  # resolved once for the whole batch
        updated_points = [
            PointStruct(
                id=point.id,
                vector=point.vector,
                payload=self._ensure_context_in_payload(
                    dict(point.payload) if isinstance(point.payload, dict) else {},
                    ids,
                ),
            )
            for point in points
        ]

        return self._manager.upsert_points(collection_name, updated_points, **kwargs)

    def search_points(
        self,
        collection_name: str,
        query_vector: Union[List[float], str],
        limit: int = 10,
        score_threshold: Optional[float] = None,
        filter: Optional[Filter] = None,
        **kwargs
    ) -> List[Any]:
        """Search points, restricted to the current tenant/company when known."""
        context_filter = self._build_context_filter(filter)
        return self._manager.search_points(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=limit,
            score_threshold=score_threshold,
            filter=context_filter,
            **kwargs
        )

    def delete_points(
        self,
        collection_name: str,
        points_selector: Union[List[int], Filter],
        **kwargs
    ) -> UpdateStatus:
        """Delete points, restricted to the current tenant/company when known.

        A Filter selector gets the context conditions added. A non-empty list or
        tuple of IDs becomes a ``has_id`` filter combined with the context
        conditions, so IDs owned by another tenant/company are not deleted.
        Other selector types, or calls without any tenant/company context, are
        passed through unscoped.
        """
        if isinstance(points_selector, Filter):
            context_filter = self._build_context_filter(points_selector)
            return self._manager.delete_points(collection_name, context_filter, **kwargs)

        if isinstance(points_selector, (list, tuple)) and points_selector:
            conditions = self._context_conditions(self._resolve_ids())
            if conditions:
                from qdrant_client.models import HasIdCondition
                scoped = Filter(must=[HasIdCondition(has_id=list(points_selector)), *conditions])
                return self._manager.delete_points(collection_name, scoped, **kwargs)

        return self._manager.delete_points(collection_name, points_selector, **kwargs)

    def get_point(self, collection_name: str, point_id: Union[int, str], **kwargs) -> Optional[Any]:
        """Fetch a single point by ID.

        NOTE(review): this is not scoped to the tenant/company, so any point ID
        can be read. Verify ownership in the caller if that matters.
        """
        return self._manager.get_point(collection_name, point_id, **kwargs)

    def close(self):
        """Close the underlying manager's client."""
        self._manager.close()

    def is_connected(self) -> bool:
        """Return whether the underlying manager is connected."""
        return self._manager.is_connected()

    def register_hook(self, event: str, hook: Callable):
        """Register a hook on the underlying manager."""
        self._manager.register_hook(event, hook)

    def unregister_hook(self, event: str, hook: Callable):
        """Unregister a hook from the underlying manager."""
        self._manager.unregister_hook(event, hook)
853
-