crewplus 0.2.89__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,917 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Author: Cursor
3
+ # @Date: 2025-02-12
4
+ # @Last Modified by: Gemini
5
+ # @Last Modified time: 2025-10-09
6
+
7
+ import asyncio
8
+ import logging
9
+ import time
10
+ import uuid
11
+ from collections import defaultdict
12
+ from typing import List, Dict, Optional
13
+
14
+ from langchain_core.documents import Document
15
+ from langchain_core.embeddings import Embeddings
16
+ from langchain_milvus import Milvus
17
+ from langchain_openai import AzureOpenAIEmbeddings
18
+ from pymilvus import MilvusClient, AsyncMilvusClient, connections
19
+
20
+ from .schema_milvus import SchemaMilvus, DEFAULT_SCHEMA
21
+ from ...services.init_services import get_model_balancer
22
+
23
+
24
+ #from .milvus_schema_manager import MilvusSchemaManager
25
+
26
+ class VDBService(object):
27
+ """
28
+ A service to manage connections to Milvus/Zilliz vector databases and embedding models.
29
+
30
+ This service centralizes the configuration and instantiation of the Milvus client
31
+ and provides helper methods to get embedding functions and vector store instances.
32
+
33
+ This service generates two unique connection aliases upon initialization: one for
34
+ the synchronous client and one for the asynchronous client. The sync alias is stored
35
+ in `connection_args`, so `langchain_milvus` instances created by this service pick it
36
+ up and reuse that connection instead of opening redundant ones, while the lazily
37
+ created async client uses its own alias to avoid connection-handler race conditions.
38
+
39
+ Args:
40
+ settings (dict, optional): A dictionary containing configuration for the vector store
41
+ and embedding models.
42
+ endpoint (str, optional): The URI for the Zilliz cluster. Can be used for simple
43
+ initialization instead of `settings`.
44
+ token (str, optional): The token for authenticating with Zilliz. Must be provided
45
+ with `endpoint`.
46
+ schema (str, optional): The schema definition for a collection. Defaults to None.
47
+ logger (logging.Logger, optional): An optional logger instance. Defaults to None.
48
+
49
+ Raises:
50
+ ValueError: If required configurations are missing.
51
+ NotImplementedError: If an unsupported provider is specified.
52
+ RuntimeError: If the MilvusClient fails to initialize after a retry.
53
+
54
+ Example:
55
+ >>> # Initialize with a full settings dictionary
56
+ >>> settings = {
57
+ ... "embedder": {
58
+ ... "provider": "azure-openai",
59
+ ... "config": {
60
+ ... "model": "text-embedding-ada-002",
61
+ ... "api_version": "2023-05-15",
62
+ ... "api_key": "YOUR_AZURE_OPENAI_KEY",
63
+ ... "openai_base_url": "YOUR_AZURE_OPENAI_ENDPOINT",
64
+ ... "embedding_dims": 1536
65
+ ... }
66
+ ... },
67
+ ... "vector_store": {
68
+ ... "provider": "milvus",
69
+ ... "config": {
70
+ ... "host": "localhost",
71
+ ... "port": 19530,
72
+ ... "user": "root",
73
+ ... "password": "password",
74
+ ... "db_name": "default"
75
+ ... }
76
+ ... },
77
+ ... "index_params": {
78
+ ... "metric_type": "IP",
79
+ ... "index_type": "HNSW",
80
+ ... "params": {}
81
+ ... }
82
+ ... }
83
+ >>> vdb_service = VDBService(settings=settings)
84
+ >>>
85
+ >>> # Alternatively, initialize with an endpoint and token for Zilliz
86
+ >>> # vdb_service_zilliz = VDBService(endpoint="YOUR_ZILLIZ_ENDPOINT", token="YOUR_ZILLIZ_TOKEN")
87
+ >>>
88
+ >>> # Get the raw Milvus client
89
+ >>> client = vdb_service.get_vector_client()
90
+ >>> print(client.list_collections())
91
+ >>> # Get an embedding function
92
+ >>> embeddings = vdb_service.get_embeddings()
93
+ >>> print(embeddings)
94
+ >>> # Get a LangChain vector store instance (will be cached)
95
+ >>> vector_store = vdb_service.get_vector_store(collection_name="my_collection")
96
+ >>> print(vector_store)
97
+ >>> same_vector_store = vdb_service.get_vector_store(collection_name="my_collection")
98
+ >>> assert vector_store is same_vector_store
99
+ """
100
+ _client: MilvusClient
101
+ _async_client: Optional[AsyncMilvusClient] = None
102
+ _instances: Dict[str, Milvus] = {}
103
+ _async_instances: Dict[str, Milvus] = {}
104
+ _async_instance_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
105
+
106
+ schema: str
107
+ embedding_function: Embeddings
108
+ index_params: dict
109
+ connection_args: dict
110
+ settings: dict
111
+
112
def __init__(self, settings: dict = None, endpoint: str = None, token: str = None, schema: str = None, logger: logging.Logger = None):
    """
    Initializes the VDBService.

    Can be initialized in two ways:
    1. By providing a full `settings` dictionary for complex configurations.
    2. By providing `endpoint` and `token` for a direct Zilliz connection.
       Note: When using this method, an `embedder` configuration is not created.
       You must either use the `ModelLoadBalancer` or pass an `Embeddings` object
       directly to methods like `get_vector_store`.

    The `vector_store.config` mapping is copied before the connection alias is
    written into it, so the caller's `settings` dictionary is never mutated and
    can safely be reused to build several services.

    Args:
        settings (dict, optional): Configuration dictionary for the service. Defaults to None.
        endpoint (str, optional): The URI for the Zilliz cluster. Used if `settings` is not provided.
        token (str, optional): The token for authenticating with the Zilliz cluster.
        schema (str, optional): Default schema for new collections. Defaults to None.
        logger (logging.Logger, optional): Logger instance. Defaults to None.

    Raises:
        ValueError: If neither `settings` nor both `endpoint` and `token` are given, or if
            `vector_store`, its `provider`, or its `config` is missing from the settings.
        NotImplementedError: If the vector store provider is not supported.
        RuntimeError: If the synchronous MilvusClient cannot be initialized.
    """
    self.logger = logger or logging.getLogger(__name__)
    self.collection_schema = None

    if settings:
        self.settings = settings
    elif endpoint and token:
        self.logger.info("Initializing VDBService with endpoint and token for a Zilliz connection.")
        self.settings = {
            "vector_store": {
                "provider": "zilliz",
                "config": {
                    "uri": endpoint,
                    "token": token
                }
            }
        }
    else:
        raise ValueError("VDBService must be initialized with either a 'settings' dictionary or both 'endpoint' and 'token'.")

    vector_store_settings = self.settings.get("vector_store")
    if not vector_store_settings:
        msg = "'vector_store' not found in settings"
        self.logger.error(msg)
        raise ValueError(msg)

    provider = vector_store_settings.get("provider")
    config = vector_store_settings.get("config")

    if not provider or not config:
        msg = "'provider' or 'config' not found in 'vector_store' settings"
        self.logger.error(msg)
        raise ValueError(msg)

    # Work on a private copy: the alias written below is unique to this service and
    # must not leak into (or be overwritten through) the caller's settings dict.
    self.connection_args = dict(config)

    self._provider = provider  # Store provider for lazy initialization

    # Create separate aliases for sync and async clients to avoid connection handler race conditions.
    self.sync_alias = f"crewplus-vdb-sync-{uuid.uuid4()}"
    self.async_alias = f"crewplus-vdb-async-{uuid.uuid4()}"

    # The default alias in connection_args should be the sync one, as langchain_milvus
    # primarily uses a synchronous client and will pick up this alias.
    self.connection_args['alias'] = self.sync_alias

    self._client = self._initialize_milvus_client(provider)

    # Do not initialize the async client here.
    # It must be lazily initialized within an async context.
    self._async_client: Optional[AsyncMilvusClient] = None

    self.schema = schema
    self.index_params = self.settings.get("index_params")

    self.logger.info("VDBService initialized successfully")
187
+
188
+ def _get_milvus_client_args(self, provider: str) -> dict:
189
+ """
190
+ Constructs the arguments for Milvus/AsyncMilvus client initialization based on the provider.
191
+ """
192
+ if provider == "milvus":
193
+ host = self.connection_args.get("host", "localhost")
194
+ port = self.connection_args.get("port", 19530)
195
+
196
+ # Use https for remote hosts, and http for local connections.
197
+ scheme = "https" if host not in ["localhost", "127.0.0.1"] else "http"
198
+ uri = f"{scheme}://{host}:{port}"
199
+
200
+ client_args = {
201
+ "uri": uri,
202
+ "user": self.connection_args.get("user"),
203
+ "password": self.connection_args.get("password"),
204
+ "db_name": self.connection_args.get("db_name"),
205
+ }
206
+ return {k: v for k, v in client_args.items() if v is not None}
207
+
208
+ elif provider == "zilliz":
209
+ # Return a copy without the default alias, as it will be added specifically for sync/async clients.
210
+ zilliz_args = self.connection_args.copy()
211
+ zilliz_args.pop('alias', None)
212
+ # 增加 gRPC keepalive 选项来加固连接
213
+ zilliz_args['channel_options'] = [
214
+ ('grpc.keepalive_time_ms', 60000), # 每 60 秒发送一次 ping
215
+ ('grpc.keepalive_timeout_ms', 20000), # 20 秒内没收到 pong 则认为连接断开
216
+ ('grpc.enable_http_proxy', 0),
217
+ ]
218
+ return zilliz_args
219
+ else:
220
+ self.logger.error(f"Unsupported vector store provider: {provider}")
221
+ raise NotImplementedError(f"Vector store provider '{provider}' is not supported.")
222
+
223
def _initialize_milvus_client(self, provider: str) -> MilvusClient:
    """
    Initializes and returns a MilvusClient, retrying once on failure.

    Args:
        provider (str): The vector store provider used to build the client arguments.

    Returns:
        MilvusClient: A client registered under this service's sync alias.

    Raises:
        NotImplementedError: If the provider is not supported.
        RuntimeError: If both connection attempts fail; the retry error is attached
            as the explicit cause.
    """
    client_args = self._get_milvus_client_args(provider)
    client_args["alias"] = self.sync_alias

    try:
        # First attempt to connect
        return MilvusClient(**client_args)
    except Exception as e:
        self.logger.error(f"Failed to initialize MilvusClient, trying again. Error: {e}")
        # Second attempt after failure
        try:
            return MilvusClient(**client_args)
        except Exception as e_retry:
            self.logger.error(f"Failed to initialize MilvusClient on retry. Final error: {e_retry}")
            # Chain explicitly, matching the async initializer, so the traceback
            # reports the retry failure as the direct cause.
            raise RuntimeError(f"Could not initialize MilvusClient after retry: {e_retry}") from e_retry
241
+
242
def _initialize_async_milvus_client(self, provider: str) -> AsyncMilvusClient:
    """
    Builds an AsyncMilvusClient bound to the async alias, retrying once after a 1 second pause.

    Args:
        provider (str): The vector store provider used to build the client arguments.

    Returns:
        AsyncMilvusClient: The newly constructed asynchronous client.

    Raises:
        NotImplementedError: If the provider is not supported.
        RuntimeError: If both construction attempts fail.
    """
    kwargs = self._get_milvus_client_args(provider)
    kwargs["alias"] = self.async_alias

    try:
        return AsyncMilvusClient(**kwargs)
    except Exception as first_error:
        self.logger.error(f"Failed to initialize AsyncMilvusClient, trying again. Error: {first_error}")
        # NOTE(review): this is a blocking sleep. `_get_or_create_async_client` calls this
        # method directly from a coroutine, so the pause stalls the event loop - confirm acceptable.
        time.sleep(1)
        try:
            return AsyncMilvusClient(**kwargs)
        except Exception as second_error:
            self.logger.error(f"Failed to initialize AsyncMilvusClient on retry. Final error: {second_error}")
            raise RuntimeError(f"Could not initialize AsyncMilvusClient after retry: {second_error}") from second_error
259
+
260
def get_vector_client(self) -> MilvusClient:
    """
    Returns the synchronous MilvusClient, creating it first if none is stored.

    Returns:
        MilvusClient: The client used for synchronous vector database operations.
    """
    if self._client is not None:
        return self._client

    self.logger.debug("Initializing synchronous MilvusClient...")
    self._client = self._initialize_milvus_client(self._provider)
    return self._client
272
+
273
async def aget_async_vector_client(self) -> AsyncMilvusClient:
    """
    Returns the AsyncMilvusClient, constructing it on first use.

    Construction is deferred to this coroutine so that it happens while an
    event loop is running. No retry is attempted here.

    Returns:
        AsyncMilvusClient: The shared asynchronous client.
    """
    if self._async_client is not None:
        return self._async_client

    self.logger.info("Lazily initializing AsyncMilvusClient...")
    async_args = self._get_milvus_client_args(self._provider)
    # The async client gets its own alias, distinct from the sync one.
    async_args['alias'] = self.async_alias
    self._async_client = AsyncMilvusClient(**async_args)
    return self._async_client
285
+
286
def get_vector_field(self, collection_name: str) -> str:
    """
    Looks up the vector field name of a collection.

    The sync instance cache is consulted first, then the async one. If neither
    holds the collection, a sync vector store is obtained via `get_vector_store`
    and its vector field is returned.

    Args:
        collection_name (str): The name of the collection.

    Returns:
        str: The name of the vector field.
    """
    for cache in (self._instances, self._async_instances):
        if collection_name in cache:
            return cache[collection_name]._vector_field

    self.logger.warning(f"No cached instance found for collection '{collection_name}' to get vector field. Creating a temporary sync instance.")
    # Fallback: less efficient, but keeps the lookup working for uncached collections.
    return self.get_vector_store(collection_name)._vector_field
309
+
310
def get_embeddings(self, from_model_balancer: bool = False, provider: Optional[str] = "azure-openai", model_type: Optional[str] = "embedding-large") -> Embeddings:
    """
    Returns an embedding model, from the model balancer or built from the 'embedder' settings.

    Args:
        from_model_balancer (bool): If True, the model is requested from the central
            model balancer. If False, a new instance is built from the 'embedder' settings.
        provider (str, optional): Provider passed to the model balancer. Only used when
            `from_model_balancer` is True; otherwise the provider comes from the settings.
        model_type (str, optional): The type of model to get from the balancer. Defaults to "embedding-large".

    Returns:
        Embeddings: An instance of a LangChain embedding model.

    Raises:
        ValueError: If the 'embedder' settings, its 'provider', or its 'config' is missing.
        NotImplementedError: If the configured provider is not "azure-openai".
    """
    if from_model_balancer:
        return get_model_balancer().get_model(provider=provider, model_type=model_type)

    embedder_settings = self.settings.get("embedder")
    if not embedder_settings:
        self.logger.error("'embedder' configuration not found in settings.")
        raise ValueError("'embedder' configuration not found in settings.")

    embedder_provider = embedder_settings.get("provider")
    embedder_options = embedder_settings.get("config")
    if not (embedder_provider and embedder_options):
        self.logger.error("Embedder 'provider' or 'config' not found in settings.")
        raise ValueError("Embedder 'provider' or 'config' not found in settings.")

    if embedder_provider != "azure-openai":
        self.logger.error(f"Unsupported embedding provider: {embedder_provider}")
        raise NotImplementedError(f"Embedding provider '{embedder_provider}' is not supported yet.")

    # Translate the settings keys into AzureOpenAIEmbeddings keyword arguments.
    option = embedder_options.get
    candidates = (
        ("azure_deployment", option("model")),
        ("openai_api_version", option("api_version")),
        ("api_key", option("api_key")),
        ("azure_endpoint", option("openai_base_url")),
        ("dimensions", option("embedding_dims")),
        ("chunk_size", option("chunk_size", 16)),
        ("request_timeout", option("request_timeout", 60)),
        ("max_retries", option("max_retries", 2)),
    )
    # Leave out unset values so the client applies its own defaults.
    azure_kwargs = {name: value for name, value in candidates if value is not None}

    return AzureOpenAIEmbeddings(**azure_kwargs)
357
+
358
+ def _check_collection_exists(self, collection_name: str) -> bool:
359
+ """
360
+ Checks if a collection exists.
361
+
362
+ Args:
363
+ collection_name (str): The name of the collection to check.
364
+
365
+ Returns:
366
+ bool: True if the collection exists, False otherwise.
367
+
368
+ Raises:
369
+ RuntimeError: If the check operation fails due to connection issues.
370
+ """
371
+ try:
372
+ client = self.get_vector_client()
373
+ return client.has_collection(collection_name)
374
+ except Exception as e:
375
+ self.logger.error(f"An error occurred while checking collection '{collection_name}': {e}")
376
+ raise RuntimeError(f"Failed to check collection '{collection_name}'.") from e
377
+
378
+ async def _acheck_collection_exists(self, collection_name: str) -> bool:
379
+ """
380
+ Asynchronously checks if a collection exists.
381
+
382
+ Args:
383
+ collection_name (str): The name of the collection to check.
384
+
385
+ Returns:
386
+ bool: True if the collection exists, False otherwise.
387
+
388
+ Raises:
389
+ RuntimeError: If the check operation fails due to connection issues.
390
+ """
391
+ try:
392
+ client = await self.aget_async_vector_client()
393
+ return await client.has_collection(collection_name)
394
+ except Exception as e:
395
+ self.logger.error(f"An error occurred while checking collection '{collection_name}': {e}")
396
+ raise RuntimeError(f"Failed to check collection '{collection_name}'.") from e
397
+
398
+ def _ensure_collection_exists(self, collection_name: str, embeddings: Embeddings, check_existence: bool = True):
399
+ """
400
+ Checks if a collection exists and creates it if it doesn't.
401
+ This operation is wrapped in a try-except block to handle potential failures
402
+ during collection creation.
403
+ """
404
+ try:
405
+ client = self.get_vector_client()
406
+ if check_existence and not client.has_collection(collection_name):
407
+ self.logger.info(f"Collection '{collection_name}' does not exist. Creating it.")
408
+
409
+ schema_milvus = SchemaMilvus(
410
+ embedding_function=embeddings,
411
+ collection_name=collection_name,
412
+ connection_args=self.connection_args,
413
+ index_params=self.index_params
414
+ )
415
+
416
+ schema_to_use = self.schema or DEFAULT_SCHEMA
417
+ if not self.schema:
418
+ self.logger.warning(f"No schema provided for VDBService. Using DEFAULT_SCHEMA for collection '{collection_name}'.")
419
+
420
+ schema_milvus.set_schema(schema_to_use)
421
+
422
+ if not schema_milvus.create_collection():
423
+ raise RuntimeError(f"SchemaMilvus failed to create collection '{collection_name}'.")
424
+ except Exception as e:
425
+ self.logger.error(f"An error occurred while ensuring collection '{collection_name}' : {e}")
426
+ raise RuntimeError(f"Failed to ensure collection '{collection_name}' .") from e
427
+
428
+ async def _aensure_collection_exists(self, collection_name: str, embeddings: Embeddings, check_existence: bool = True):
429
+ """
430
+ Asynchronously checks if a collection exists and creates it if it doesn't.
431
+ """
432
+ try:
433
+ # Call the new lazy initializer for the async client
434
+ client = await self.aget_async_vector_client()
435
+ if check_existence and not await client.has_collection(collection_name):
436
+ self.logger.info(f"Collection '{collection_name}' does not exist. Creating it.")
437
+
438
+ schema_milvus = SchemaMilvus(
439
+ embedding_function=embeddings,
440
+ collection_name=collection_name,
441
+ connection_args=self.connection_args,
442
+ index_params=self.index_params
443
+ )
444
+
445
+ #ensure using async connection alias
446
+ schema_milvus.aclient._using = self.async_alias
447
+
448
+ schema_to_use = self.schema or DEFAULT_SCHEMA
449
+ if not self.schema:
450
+ self.logger.warning(f"No schema provided for VDBService. Using DEFAULT_SCHEMA for collection '{collection_name}'.")
451
+
452
+ schema_milvus.set_schema(schema_to_use)
453
+
454
+ if not await schema_milvus.acreate_collection():
455
+ raise RuntimeError(f"SchemaMilvus failed to create collection '{collection_name}'.")
456
+ except Exception as e:
457
+ self.logger.error(f"An error occurred while ensuring collection '{collection_name}' : {e}")
458
+ raise RuntimeError(f"Failed to ensure collection '{collection_name}' .") from e
459
+
460
+ def _is_good_connection(self, vdb_instance: Milvus, collection_name: str) -> tuple[bool, bool | None]:
461
+ """
462
+ Checks if the Milvus instance has a good connection by verifying collection existence.
463
+
464
+ Args:
465
+ vdb_instance (Milvus): The cached vector store instance.
466
+ collection_name (str): The name of the collection to check.
467
+
468
+ Returns:
469
+ tuple[bool, bool | None]: A tuple of (is_connected, collection_exists).
470
+ collection_exists is None if the connection failed.
471
+ """
472
+ try:
473
+ # Use has_collection as a lightweight way to verify the connection and collection status.
474
+ # If the server is unreachable, this will raise an exception.
475
+ collection_exists = vdb_instance.client.has_collection(collection_name)
476
+ if collection_exists:
477
+ self.logger.debug(f"Connection for cached instance of '{collection_name}' is alive.")
478
+ else:
479
+ self.logger.warning(f"Collection '{collection_name}' not found for cached instance. It may have been dropped.")
480
+ return True, collection_exists
481
+ except Exception as e:
482
+ self.logger.warning(f"Connection check failed for cached instance of '{collection_name}': {e}")
483
+ return False, None
484
+
485
+ async def _ais_good_connection(self, vdb_instance: Milvus, collection_name: str) -> tuple[bool, bool | None]:
486
+ """
487
+ Asynchronously checks if the Milvus instance has a good connection.
488
+ """
489
+ try:
490
+ collection_exists = await vdb_instance.aclient.has_collection(collection_name)
491
+ if collection_exists:
492
+ self.logger.debug(f"Connection for cached instance of '{collection_name}' is alive.")
493
+ else:
494
+ self.logger.warning(f"Collection '{collection_name}' not found for cached instance. It may have been dropped.")
495
+ return True, collection_exists
496
+ except Exception as e:
497
+ self.logger.warning(f"Connection check failed for cached instance of '{collection_name}': {e}")
498
+ return False, None
499
+
500
def get_vector_store(self, collection_name: str, embeddings: Embeddings = None, metric_type: str = "IP") -> Milvus:
    """
    Returns the cached vector store for a collection, building and caching one on first use.

    The collection must already exist; it is never created implicitly. The
    embedding function is not validated here (the validation call is disabled).

    Args:
        collection_name (str): The name of the collection in the vector database.
        embeddings (Embeddings, optional): An embedding model instance. If None, one is
            obtained from `get_embeddings()`.
        metric_type (str): Metric used for the fallback index parameters when the
            service has no `index_params` configured. Defaults to "IP".

    Returns:
        Milvus: LangChain Milvus instance, which is compatible with both Zilliz and Milvus.

    Raises:
        ValueError: If `collection_name` is empty or the collection does not exist.
        RuntimeError: If the existence check or the instance creation fails.
    """
    if not collection_name:
        self.logger.error("get_vector_store called with no collection_name.")
        raise ValueError("collection_name must be provided.")

    cache = self._instances
    if collection_name in cache:
        # Cache hit: hand back the stored instance without any connection check.
        self.logger.info(f"Returning existing vector store instance for collection: {collection_name}")
        return cache[collection_name]

    self.logger.info(f"Creating new vector store instance for collection: {collection_name}")
    if embeddings is None:
        embeddings = self.get_embeddings()

    # Implicit creation is not supported, so a missing collection is an error.
    exists = self._check_collection_exists(collection_name)
    if not exists:
        self.logger.error(f"Collection '{collection_name}' does not exist. Implicit collection creation is not supported.")
        raise ValueError(f"Collection '{collection_name}' does not exist. Please create the collection explicitly before use.")

    # Prefer the configured index parameters; otherwise build an AUTOINDEX default.
    index_params = self.index_params
    if not index_params:
        index_params = {
            "metric_type": metric_type,
            "index_type": "AUTOINDEX",
            "params": {}
        }

    store = self._create_milvus_instance_with_retry(
        collection_name=collection_name,
        embeddings=embeddings,
        index_params=index_params
    )

    cache[collection_name] = store
    return store
561
+
562
+ async def _get_or_create_async_client(self) -> AsyncMilvusClient:
563
+ """
564
+ Lazily initializes the AsyncMilvusClient.
565
+ Based on grpcio source, the client MUST be initialized in a thread
566
+ with a running event loop. Therefore, we initialize it directly in the
567
+ main async context. The synchronous __init__ is fast enough not to
568
+ block the event loop meaningfully.
569
+ """
570
+ if self._async_client is None:
571
+ self.logger.info("Lazily initializing AsyncMilvusClient directly in the main event loop...")
572
+ provider = self.settings.get("vector_store", {}).get("provider")
573
+ # This is a synchronous call, but it's lightweight and must run here.
574
+ self._async_client = self._initialize_async_milvus_client(provider)
575
+
576
+ return self._async_client
577
+
578
async def aget_vector_store(self, collection_name: str, embeddings: Embeddings = None, metric_type: str = "IP") -> Milvus:
    """
    Asynchronously returns the cached vector store for a collection, building one on first use.

    A per-collection asyncio lock serializes the cache lookup and the creation,
    so concurrent callers for the same collection share one instance.

    Args:
        collection_name (str): The name of the collection in the vector database.
        embeddings (Embeddings, optional): An embedding model instance. If None, one is
            obtained from `get_embeddings()`.
        metric_type (str): Forwarded to `_acreate_milvus_instance_with_retry`. Defaults to "IP".

    Returns:
        Milvus: The LangChain Milvus instance, with its async client switched to the async alias.

    Raises:
        ValueError: If `collection_name` is empty or the collection does not exist.
        RuntimeError: If the existence check or the instance creation fails.
    """
    if not collection_name:
        self.logger.error("aget_vector_store called with no collection_name.")
        raise ValueError("collection_name must be provided.")

    async with self._async_instance_locks[collection_name]:
        cache = self._async_instances
        if collection_name in cache:
            self.logger.info(f"Returning existing async vector store instance for collection: {collection_name} (post-lock)")
            return cache[collection_name]

        self.logger.info(f"Creating new async vector store instance for collection: {collection_name}")
        if embeddings is None:
            embeddings = self.get_embeddings()

        # The shared async client must exist before anything below uses it,
        # and it must be created while the lock is held.
        await self._get_or_create_async_client()

        # Implicit creation is not supported, so a missing collection is an error.
        exists = await self._acheck_collection_exists(collection_name)
        if not exists:
            self.logger.error(f"Collection '{collection_name}' does not exist. Implicit collection creation is not supported.")
            raise ValueError(f"Collection '{collection_name}' does not exist. Please create the collection explicitly before use.")

        store = await self._acreate_milvus_instance_with_retry(
            collection_name=collection_name,
            embeddings=embeddings,
            metric_type=metric_type
        )

        self.logger.info(f"Swapping to async alias for instance of collection {collection_name}")
        store.aclient._using = self.async_alias

        cache[collection_name] = store
        return store
617
+
618
async def _acreate_milvus_instance_with_retry(
    self,
    embeddings: Embeddings,
    collection_name: str,
    metric_type: str = "IP",
) -> Milvus:
    """
    Asynchronously creates a Milvus instance in a worker thread, retrying with backoff.

    Up to three attempts are made. After a failed attempt that will be retried the
    coroutine waits 1s, then 2s; after the final failed attempt it raises
    immediately instead of waiting for a retry that will never happen.

    Args:
        embeddings (Embeddings): Embedding function for the Milvus instance.
        collection_name (str): The collection the instance is bound to.
        metric_type (str): Accepted for signature compatibility; it is deliberately
            not forwarded to the Milvus constructor. The service's `index_params`
            are passed instead.

    Returns:
        Milvus: The newly created LangChain Milvus instance.

    Raises:
        RuntimeError: If every attempt fails; the last error is attached as the cause.
    """
    retries = 3
    last_exception = None

    for attempt in range(retries):
        try:
            conn_args = self.connection_args.copy()
            # Langchain's Milvus class will use the alias to find the connection.
            conn_args["alias"] = self.sync_alias

            def _create_instance_in_thread():
                # Connect explicitly inside the worker thread before building the
                # Milvus instance, so the connection is registered under the alias
                # that the instance looks up.
                try:
                    connections.connect(**conn_args)
                    self.logger.info(f"Successfully connected to Milvus with alias '{self.sync_alias}' in thread.")
                except Exception as e:
                    self.logger.error(f"Failed to manually connect in thread: {e}")
                    raise

                # Creating the Milvus instance now finds the existing connection via the alias.
                return Milvus(
                    embedding_function=embeddings,
                    collection_name=collection_name,
                    connection_args=conn_args,  # Pass args for completeness
                    consistency_level="Strong",
                    index_params=self.index_params,
                )

            self.logger.info(f"Attempt {attempt + 1}/{retries}: Creating Milvus instance for collection '{collection_name}' in a separate thread...")
            vdb = await asyncio.to_thread(_create_instance_in_thread)
            self.logger.info("Successfully created Milvus instance.")
            return vdb

        except Exception as e:
            last_exception = e
            if attempt + 1 >= retries:
                # Final attempt: nothing left to retry, so do not sleep before failing.
                self.logger.warning(
                    f"Attempt {attempt + 1}/{retries} failed to create Milvus instance: {e}. No retries left."
                )
                break
            delay = 2 ** attempt
            self.logger.warning(
                f"Attempt {attempt + 1}/{retries} failed to create Milvus instance: {e}. Retrying in {delay}s..."
            )
            await asyncio.sleep(delay)

    raise RuntimeError(
        f"Failed to create Milvus instance after {retries} retries."
    ) from last_exception
678
+
679
def _create_milvus_instance_with_retry(self, collection_name: str, embeddings: Embeddings, index_params: dict, connection_args: Optional[dict] = None) -> Milvus:
    """
    Builds a LangChain Milvus instance, making up to three attempts 3 seconds apart.

    Args:
        collection_name (str): The collection the instance is bound to.
        embeddings (Embeddings): Embedding function for the Milvus instance.
        index_params (dict): Index parameters passed to the Milvus constructor.
        connection_args (dict, optional): Connection arguments to use instead of
            the service's own `connection_args`.

    Returns:
        Milvus: The newly created instance.

    Raises:
        RuntimeError: If every attempt fails; the last error is attached as the cause.
    """
    retries = 2
    total_attempts = retries + 1
    effective_args = self.connection_args if connection_args is None else connection_args

    for attempt_no in range(1, total_attempts + 1):
        try:
            store = Milvus(
                embedding_function=embeddings,
                collection_name=collection_name,
                connection_args=effective_args,
                index_params=index_params
            )
            self.logger.info(f"Successfully connected to Milvus for collection '{collection_name}' on attempt {attempt_no}.")
            return store
        except Exception as error:
            self.logger.warning(
                f"Attempt {attempt_no}/{total_attempts} to connect to Milvus for collection '{collection_name}' failed: {error}"
            )
            if attempt_no >= total_attempts:
                # Out of attempts: report and surface the last error as the cause.
                self.logger.error(f"Failed to connect to Milvus for collection '{collection_name}' after {total_attempts} attempts.")
                raise RuntimeError(f"Could not connect to Milvus after {total_attempts} attempts.") from error
            self.logger.info("Retrying in 3 seconds...")
            time.sleep(3)
705
+
706
def drop_collection(self, collection_name: str) -> None:
    """
    Deletes a collection from the vector database and evicts it from the instance caches.

    Both the sync and the async instance caches are cleared for this collection,
    so neither `get_vector_store` nor `aget_vector_store` can return a store
    bound to the dropped collection afterwards.

    Args:
        collection_name (str): The name of the collection to drop.

    Raises:
        ValueError: If collection_name is not provided.
        RuntimeError: If the operation fails on the database side.
    """
    if not collection_name:
        self.logger.error("drop_collection called without a collection_name.")
        raise ValueError("collection_name must be provided.")

    self.logger.info(f"Attempting to drop collection: {collection_name}")

    try:
        client = self.get_vector_client()
        client.drop_collection(collection_name=collection_name)
        self.logger.info(f"Successfully dropped collection: {collection_name}")
    except Exception as e:
        self.logger.error(f"Failed to drop collection '{collection_name}': {e}")
        raise RuntimeError(f"An error occurred while dropping collection '{collection_name}'.") from e
    finally:
        # Whether successful or not, remove the stale instance from every cache.
        removed_sync = self._instances.pop(collection_name, None) is not None
        removed_async = self._async_instances.pop(collection_name, None) is not None
        if removed_sync or removed_async:
            self.logger.info(f"Removed '{collection_name}' from instance cache.")
735
+
736
async def adrop_collection(self, collection_name: str) -> None:
    """
    Drop a collection through the asynchronous client and evict it from the cache.

    The cache entry in ``self._async_instances`` is removed whether or not the
    drop succeeds, so a failed drop never leaves a possibly stale instance behind.

    Args:
        collection_name (str): The name of the collection to drop.

    Raises:
        ValueError: If collection_name is empty or None.
        RuntimeError: If obtaining the client or dropping the collection fails.
    """
    log = self.logger

    # Guard clause: refuse to talk to the database without a target.
    if not collection_name:
        log.error("adrop_collection called without a collection_name.")
        raise ValueError("collection_name must be provided.")

    log.info(f"Attempting to drop collection asynchronously: {collection_name}")

    try:
        async_client = await self.aget_async_vector_client()
        await async_client.drop_collection(collection_name=collection_name)
        log.info(f"Successfully dropped collection asynchronously: {collection_name}")
    except Exception as exc:
        log.error(f"Failed to drop collection '{collection_name}' asynchronously: {exc}")
        raise RuntimeError(
            f"An error occurred while dropping collection '{collection_name}' asynchronously."
        ) from exc
    finally:
        # Runs on success and on failure alike.
        cache = self._async_instances
        if collection_name in cache:
            del cache[collection_name]
            log.info(f"Removed '{collection_name}' from instance cache.")
765
+
766
def delete_data_by_filter(self, collection_name: str = None, filter: str = None) -> None:
    """Delete the entities of a collection that match a filter expression.

    Args:
        collection_name (str): Name of the collection to delete from.
        filter (str): Filter expression forwarded to the client's ``delete`` call.

    Raises:
        RuntimeError: If ``collection_name`` or ``filter`` is None, if no client
            is available, or if the delete call itself fails.
    """
    self.logger.info(f"Delete data by filter:{filter}")

    # Validate before touching the database. The previous code *returned* a
    # RuntimeError instance here instead of raising it, so invalid calls were
    # silently ignored.
    if collection_name is None or filter is None:
        raise RuntimeError("collection_name and filter must both be provided.")

    try:
        client = self.get_vector_client()
        if client is None:
            raise RuntimeError("Milvus client is not available; check the connection settings.")
        client.delete(collection_name=collection_name, filter=filter)
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"delete collection data failed: {str(e)}") from e
782
+
783
async def adelete_data_by_filter(self, collection_name: str = None, filter: str = None) -> None:
    """Asynchronously delete the entities of a collection that match a filter expression.

    Args:
        collection_name (str): Name of the collection to delete from.
        filter (str): Filter expression forwarded to the async client's ``delete`` call.

    Raises:
        RuntimeError: If ``collection_name`` or ``filter`` is None, if no client
            is available, or if the delete call itself fails.
    """
    self.logger.info(f"Delete data by filter asynchronously:{filter}")

    # Validate before touching the database. The previous code *returned* a
    # RuntimeError instance here instead of raising it, so invalid calls were
    # silently ignored.
    if collection_name is None or filter is None:
        raise RuntimeError("collection_name and filter must both be provided.")

    try:
        client = await self.aget_async_vector_client()
        if client is None:
            raise RuntimeError("Milvus client is not available; check the connection settings.")
        await client.delete(collection_name=collection_name, filter=filter)
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"delete collection data failed: {str(e)}") from e
799
+
800
async def aget_docs_by_ids(self, collection_names: List[str], ids: List[str], embeddings: Embeddings, output_fields: List[str] = None) -> List[Document]:
    """
    Retrieves documents from multiple collections by their `source_id` using a direct query.

    Each collection is queried concurrently with a ``source_id in [...]`` filter.
    Results are merged and de-duplicated on ``source_id``: the first document
    seen for a given ``source_id`` wins. A collection whose query fails is
    logged and contributes no documents instead of failing the whole call.

    Args:
        collection_names: A list of collection names to search within.
        ids: A list of `source_id` values to retrieve.
        embeddings: The embedding function instance, forwarded to `aget_vector_store`.
        output_fields: A list of fields to return from the database. If None,
            ``["pk", "source_id", "text"]`` is used. ``source_id`` is always
            requested because de-duplication depends on it.

    Returns:
        A list of LangChain Document objects. The ``text`` field becomes
        ``page_content``; every other returned field goes into ``metadata``.
    """
    if not ids or not collection_names:
        return []

    # The query targets the `source_id` field (VARCHAR business key), not `pk`.
    query_field = "source_id"

    if output_fields is None:
        fields = ["pk", "source_id", "text"]
    else:
        # Copy so the caller's list is never mutated, and make sure the
        # de-duplication key is fetched; without it every document would be
        # discarded by the merge step below.
        fields = list(output_fields)
        if query_field not in fields:
            fields.append(query_field)

    # The filter is identical for every collection, so build it once.
    # NOTE(review): ids are interpolated verbatim; an id containing a double
    # quote would produce an invalid expression - confirm ids are always UUIDs.
    formatted_ids = ", ".join(f'"{id_val!s}"' for id_val in ids)
    expr = f'{query_field} in [{formatted_ids}]'

    async def _query_one_collection(collection_name: str) -> List[Document]:
        """Helper coroutine to query a single collection."""
        try:
            # The returned store is not needed here; the call is kept because
            # it presumably initializes/caches the collection - TODO confirm.
            await self.aget_vector_store(collection_name, embeddings)

            # The langchain_milvus instance exposes no async query method, so
            # the underlying pymilvus async client is used directly.
            async_client = await self.aget_async_vector_client()

            self.logger.info(f"Querying collection '{collection_name}' asynchronously: {expr}")
            results = await async_client.query(
                collection_name=collection_name,
                filter=expr,
                output_fields=fields
            )

            # 'text' is popped out as the page content; whatever remains in
            # the row dictionary becomes the document metadata.
            return [
                Document(page_content=res.pop('text', ''), metadata=res)
                for res in results
            ]
        except Exception as e:
            self.logger.error(f"Failed to retrieve documents by ID from collection '{collection_name}': {e}")
            return []

    # Run queries concurrently across all specified collections.
    list_of_doc_lists = await asyncio.gather(
        *(_query_one_collection(name) for name in collection_names)
    )

    # Flatten and keep the first document seen for each source_id, in case
    # the same id exists in several collections.
    all_docs = {}
    for doc_list in list_of_doc_lists:
        for doc in doc_list:
            doc_id = doc.metadata.get("source_id")
            if doc_id:
                all_docs.setdefault(doc_id, doc)

    return list(all_docs.values())
876
+
877
@staticmethod
def delete_old_indexes(url: str = None, vdb: Milvus = None) -> (bool | None):
    """ Remove every entity whose `source_url` or `source` equals the given url.

    Args:
        url (str): source url to match
        vdb (Milvus): Milvus/Zilliz instance to delete from

    Returns:
        The result of ``vdb.delete`` when matching primary keys were found;
        None when an argument is missing or nothing matched.
    """
    # No logger is available here: this is a static method.
    if vdb is None or url is None:
        return None

    # Look up the primary keys of all entities indexed from this url.
    matching_pks = vdb.get_pks(f'source_url == "{url}" or source == "{url}"')
    if matching_pks is None or len(matching_pks) == 0:
        return None

    # Delete the matched entities by primary key.
    return vdb.delete(matching_pks)
897
+
898
@staticmethod
def delete_old_indexes_by_id(source_id: str = None, vdb: Milvus = None) -> (bool | None):
    """ Remove every entity whose `source_id` equals the given id.

    Args:
        source_id (str): source id to match
        vdb (Milvus): Milvus/Zilliz instance to delete from

    Returns:
        The result of ``vdb.delete`` when matching primary keys were found;
        None when an argument is missing or nothing matched.
    """
    # No logger is available here: this is a static method.
    if vdb is None or source_id is None:
        return None

    # Look up the primary keys of all entities carrying this source_id.
    matching_pks = vdb.get_pks(f'source_id == "{source_id}"')
    if matching_pks is None or len(matching_pks) == 0:
        return None

    # Delete the matched entities by primary key.
    return vdb.delete(matching_pks)
917
+