crewplus 0.2.68__py3-none-any.whl → 0.2.71__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of crewplus might be problematic.

crewplus/services/tracing_manager.py

@@ -9,16 +9,11 @@ import logging
 try:
     from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler
     from ..callbacks.async_langfuse_handler import AsyncLangfuseCallbackHandler
-    # Import the new custom handlers
-    from ..callbacks.run_id_handler import RunIdCallbackHandler, AsyncRunIdCallbackHandler
     LANGFUSE_AVAILABLE = True
 except ImportError:
     LANGFUSE_AVAILABLE = False
     LangfuseCallbackHandler = None
     AsyncLangfuseCallbackHandler = None
-    # Define dummy classes for the custom handlers to prevent errors if langfuse is not installed
-    class RunIdCallbackHandler: pass
-    class AsyncRunIdCallbackHandler: pass

 class TracingContext(Protocol):
     """
@@ -101,24 +96,20 @@ class TracingManager:
         # by detecting its environment variables.
         enable_langfuse = self.context.enable_tracing
         if enable_langfuse is None:  # Auto-detect if not explicitly set
-            self.context.logger.debug("enable_tracing is None, auto-detecting via environment variables.")
             langfuse_env_vars = ["LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY"]
             enable_langfuse = any(os.getenv(var) for var in langfuse_env_vars)
-            self.context.logger.debug(f"Auto-detection result for enable_langfuse: {enable_langfuse}")

         if enable_langfuse:
             try:
-                self.context.logger.debug("Langfuse is enabled, creating handlers.")
-                # Create and add both the standard and the run_id-capturing handlers.
-                # The standard handler creates the trace, and the custom one updates it.
-                self._sync_handlers.append(LangfuseCallbackHandler())
-                self._sync_handlers.append(RunIdCallbackHandler())
+                # Create both sync and async handlers. We'll pick one at runtime.
+                sync_handler = LangfuseCallbackHandler()
+                self._sync_handlers.append(sync_handler)

                 if AsyncLangfuseCallbackHandler:
-                    self._async_handlers.append(AsyncLangfuseCallbackHandler())
-                    self._async_handlers.append(AsyncRunIdCallbackHandler())
+                    async_handler = AsyncLangfuseCallbackHandler()
+                    self._async_handlers.append(async_handler)

-                self.context.logger.info(f"Langfuse tracing enabled for {self.context.get_model_identifier()}. Sync handlers loaded: {len(self._sync_handlers)}")
+                self.context.logger.info(f"Langfuse tracing enabled for {self.context.get_model_identifier()}")
             except Exception as e:
                 self.context.logger.warning(f"Failed to initialize Langfuse: {e}", exc_info=True)
         else:
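For readers skimming the diff, the handler setup above boils down to: auto-detect Langfuse from its environment variables when enable_tracing is unset, then register one sync and (if available) one async Langfuse handler, dropping the custom run_id handlers. A minimal standalone sketch of that pattern follows; build_tracing_handlers is a hypothetical helper, not a crewplus API, and only the environment variable names and the CallbackHandler import are taken from the diff.

import logging
import os
from typing import List, Optional

# Optional dependency, mirroring the guarded import at the top of tracing_manager.py.
try:
    from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler
    LANGFUSE_AVAILABLE = True
except ImportError:
    LangfuseCallbackHandler = None
    LANGFUSE_AVAILABLE = False

logger = logging.getLogger(__name__)


def build_tracing_handlers(enable_tracing: Optional[bool] = None) -> List[object]:
    """Return callback handlers to attach, auto-detecting Langfuse when unset."""
    if enable_tracing is None:
        # Auto-detect: tracing is enabled only if Langfuse credentials are present.
        enable_tracing = any(
            os.getenv(var) for var in ("LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY")
        )

    handlers: List[object] = []
    if enable_tracing and LANGFUSE_AVAILABLE:
        try:
            handlers.append(LangfuseCallbackHandler())
            logger.info("Langfuse tracing enabled")
        except Exception as exc:  # e.g. missing keys or unreachable host
            logger.warning("Failed to initialize Langfuse: %s", exc, exc_info=True)
    return handlers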
crewplus/vectorstores/milvus/vdb_service.py

@@ -5,14 +5,15 @@
 # @Last Modified time: 2025-10-09

 import logging
-from typing import List, Dict, Union, Optional
+from typing import List, Dict, Union, Optional, Any
 from langchain_milvus import Milvus
 from langchain_core.embeddings import Embeddings
 from langchain_openai import AzureOpenAIEmbeddings
-from pymilvus import MilvusClient, AsyncMilvusClient
+from pymilvus import MilvusClient, AsyncMilvusClient, connections
 import time
 import asyncio
 import uuid
+from collections import defaultdict

 from ...services.init_services import get_model_balancer
 from .schema_milvus import SchemaMilvus, DEFAULT_SCHEMA
@@ -96,6 +97,7 @@ class VDBService(object):
     _async_client: Optional[AsyncMilvusClient] = None
     _instances: Dict[str, Milvus] = {}
     _async_instances: Dict[str, Milvus] = {}
+    _async_instance_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

     schema: str
     embedding_function: Embeddings
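The new _async_instance_locks attribute gives each collection name its own asyncio.Lock, enabling a "lock, then re-check the cache" pattern when many coroutines ask for the same vector store at once. A minimal standalone sketch of that pattern, with illustrative names (get_instance, _build) that are not part of crewplus:

import asyncio
from collections import defaultdict
from typing import Dict

_instances: Dict[str, str] = {}
# defaultdict creates a fresh asyncio.Lock the first time a key is seen,
# so each collection name is serialized independently of the others.
_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)


async def _build(name: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for slow client/collection setup
    return f"vector-store:{name}"


async def get_instance(name: str) -> str:
    async with _locks[name]:
        # Re-check inside the lock: another coroutine may have finished
        # building the instance while we were waiting.
        if name not in _instances:
            _instances[name] = await _build(name)
        return _instances[name]


async def main() -> None:
    # Ten concurrent callers, one build per collection name.
    stores = await asyncio.gather(*(get_instance("docs") for _ in range(10)))
    assert len(set(stores)) == 1


if __name__ == "__main__":
    asyncio.run(main())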
@@ -168,6 +170,10 @@ class VDBService(object):
         # lazy-initialize async milvus
         # self._async_client = self._initialize_async_milvus_client(provider)

+        # Do not initialize the async client here.
+        # It must be lazily initialized within an async context.
+        self._async_client: Optional[AsyncMilvusClient] = None
+
         self.schema = schema
         self.index_params = self.settings.get("index_params")
@@ -256,12 +262,16 @@ class VDBService(object):

     async def aget_async_vector_client(self) -> AsyncMilvusClient:
         """
-        Asynchronously returns the active AsyncMilvusClient instance, initializing it if necessary.
-
-        Returns:
-            AsyncMilvusClient: The initialized async client for interacting with the vector database.
+        Lazily initializes and returns the AsyncMilvusClient.
+        This ensures the client is created within the running event loop.
         """
-        return await self._get_or_create_async_client()
+        if self._async_client is None:
+            self.logger.info("Lazily initializing AsyncMilvusClient...")
+            client_args = self._get_milvus_client_args(self._provider)
+            # Use the dedicated async alias
+            client_args['alias'] = self.async_alias
+            self._async_client = AsyncMilvusClient(**client_args)
+        return self._async_client

     def get_vector_field(self, collection_name: str) -> str:
         """
@@ -370,6 +380,7 @@ class VDBService(object):
         Asynchronously checks if a collection exists and creates it if it doesn't.
         """
         try:
+            # Call the new lazy initializer for the async client
             client = await self.aget_async_vector_client()
             if check_existence and not await client.has_collection(collection_name):
                 self.logger.info(f"Collection '{collection_name}' does not exist. Creating it.")
@@ -498,130 +509,117 @@ class VDBService(object):

     async def _get_or_create_async_client(self) -> AsyncMilvusClient:
         """
-        Lazily initializes and returns the AsyncMilvusClient.
-        This runs the blocking constructor in a separate thread, but also creates
-        a temporary event loop inside that thread to satisfy the client's
-        initialization requirements.
+        Lazily initializes the AsyncMilvusClient.
+        Based on grpcio source, the client MUST be initialized in a thread
+        with a running event loop. Therefore, we initialize it directly in the
+        main async context. The synchronous __init__ is fast enough not to
+        block the event loop meaningfully.
         """
         if self._async_client is None:
-            self.logger.info("Lazily initializing AsyncMilvusClient...")
-
-            def _create_with_loop():
-                # This function runs in a separate thread via asyncio.to_thread
-                try:
-                    # Check if an event loop exists in this new thread
-                    asyncio.get_running_loop()
-                except RuntimeError:  # 'RuntimeError: There is no current event loop...'
-                    # If not, create and set a new one
-                    loop = asyncio.new_event_loop()
-                    asyncio.set_event_loop(loop)
-
-                # Now, with an event loop present in this thread, initialize the client.
-                # This is still a blocking call, but it's contained in the thread.
-                provider = self.settings.get("vector_store", {}).get("provider")
-                return self._initialize_async_milvus_client(provider)
-
-            self._async_client = await asyncio.to_thread(_create_with_loop)
+            self.logger.info("Lazily initializing AsyncMilvusClient directly in the main event loop...")
+            provider = self.settings.get("vector_store", {}).get("provider")
+            # This is a synchronous call, but it's lightweight and must run here.
+            self._async_client = self._initialize_async_milvus_client(provider)

         return self._async_client

     async def aget_vector_store(self, collection_name: str, embeddings: Embeddings = None, metric_type: str = "IP") -> Milvus:
         """
         Asynchronously gets a vector store instance, creating it if it doesn't exist.
+        This version is optimized to handle high concurrency using a lock.
         """
         if not collection_name:
             self.logger.error("aget_vector_store called with no collection_name.")
             raise ValueError("collection_name must be provided.")

-        # Check for a cached instance. If found, return it immediately.
-        if collection_name in self._async_instances:
-            self.logger.info(f"Returning existing async vector store instance for collection: {collection_name}")
-            return self._async_instances[collection_name]
-
-        self.logger.info(f"Creating new async vector store instance for collection: {collection_name}")
-        if embeddings is None:
-            embeddings = self.get_embeddings()
+        lock = self._async_instance_locks[collection_name]
+        async with lock:
+            if collection_name in self._async_instances:
+                self.logger.info(f"Returning existing async vector store instance for collection: {collection_name} (post-lock)")
+                return self._async_instances[collection_name]

-        await self._aensure_collection_exists(collection_name, embeddings, check_existence=True)
+            self.logger.info(f"Creating new async vector store instance for collection: {collection_name}")
+            if embeddings is None:
+                embeddings = self.get_embeddings()

-        # try:
-        #     self.logger.info(f"Testing embedding function for collection '{collection_name}'...")
-        #     await embeddings.aembed_query("validation_test_string")
-        #     self.logger.info("Embedding function is valid.")
-        # except Exception as e:
-        #     self.logger.error(
-        #         f"The provided embedding function is invalid and failed with error: {e}. "
-        #         f"Cannot create a vector store for collection '{collection_name}'."
-        #     )
-        #     raise RuntimeError(f"Invalid embedding function provided.") from e
+            # CRITICAL: Ensure the shared async client is initialized *under the lock*
+            # before any operation that might use it.
+            await self._get_or_create_async_client()

-        index_params = self.index_params or {
-            "metric_type": metric_type,
-            "index_type": "AUTOINDEX",
-            "params": {}
-        }
-
-        # Create a dedicated connection_args for the async path with the correct alias
-        async_conn_args = self.connection_args.copy()
-        async_conn_args['alias'] = self.async_alias
-
-        # For async operations, we MUST instantiate the Milvus object using the SYNCHRONOUS alias
-        # because its __init__ method is synchronous. This is now done in a separate thread.
-        vdb = await self._acreate_milvus_instance_with_retry(
-            collection_name=collection_name,
-            embeddings=embeddings,
-            index_params=index_params,
-            connection_args=async_conn_args  # Pass the async-specific connection args
-        )
+            await self._aensure_collection_exists(collection_name, embeddings, check_existence=True)

-        # After successful synchronous initialization, we hot-swap the alias on the
-        # ASYNCHRONOUS client to ensure future async operations use the correct connection.
-        self.logger.info(f"Swapping to async alias for instance of collection {collection_name}")
-        # DO NOT get the async client here, get it outside this function
-        #await self._get_or_create_async_client()
-        vdb.aclient._using = self.async_alias
+            vdb = await self._acreate_milvus_instance_with_retry(
+                collection_name=collection_name,
+                embeddings=embeddings,
+                metric_type=metric_type
+            )

-        self._async_instances[collection_name] = vdb
+            self.logger.info(f"Swapping to async alias for instance of collection {collection_name}")
+            vdb.aclient._using = self.async_alias

-        return vdb
+            self._async_instances[collection_name] = vdb
+            return vdb

-    async def _acreate_milvus_instance_with_retry(self, collection_name: str, embeddings: Embeddings, index_params: dict, connection_args: Optional[dict] = None) -> Milvus:
+    async def _acreate_milvus_instance_with_retry(
+        self,
+        embeddings: Embeddings,
+        collection_name: str,
+        metric_type: str = "IP",
+    ) -> Milvus:
         """
-        Asynchronously creates a Milvus instance with a retry mechanism, running the synchronous
-        constructor in a separate thread to avoid blocking the event loop.
+        Asynchronously creates a Milvus instance with retry logic, ensuring the connection
+        is established in the target thread.
         """
-        retries = 2
-        conn_args = connection_args if connection_args is not None else self.connection_args
-
-        def _create_instance():
-            # This synchronous function will be run in a thread
-            return Milvus(
-                embedding_function=embeddings,
-                collection_name=collection_name,
-                connection_args=conn_args,
-                index_params=index_params
-            )
-
-        self.logger.info(f"Creating Milvus instance for collection '{collection_name}' in a separate thread...")
-        self.logger.info(f"Connection args: {conn_args}")
+        retries = 3
+        last_exception = None

-        for attempt in range(retries + 1):
+        for attempt in range(retries):
             try:
-                # Run the blocking constructor in a separate thread
-                vdb = await asyncio.to_thread(_create_instance)
-
-                self.logger.info(f"Successfully connected to Milvus for collection '{collection_name}' on attempt {attempt + 1}.")
-                return vdb  # Return on success
+                conn_args = self.connection_args.copy()
+                # Langchain's Milvus class will use the alias to find the connection.
+                conn_args["alias"] = self.sync_alias
+
+                def _create_instance_in_thread():
+                    # --- START: CRITICAL FIX ---
+                    # Manually connect within the thread before creating the Milvus instance.
+                    # This ensures pymilvus registers the connection details for the current thread.
+                    try:
+                        connections.connect(**conn_args)
+                        self.logger.info(f"Successfully connected to Milvus with alias '{self.sync_alias}' in thread.")
+                    except Exception as e:
+                        self.logger.error(f"Failed to manually connect in thread: {e}")
+                        raise
+
+                    # Now, creating the Milvus instance will find the existing connection via the alias.
+                    instance = Milvus(
+                        embedding_function=embeddings,
+                        collection_name=collection_name,
+                        connection_args=conn_args,  # Pass args for completeness
+                        # metric_type=metric_type,  # <-- CRITICAL FIX: REMOVE THIS LINE
+                        consistency_level="Strong",
+                        # --- START: CRITICAL FIX ---
+                        # Pass self.index_params to the Milvus constructor here
+                        index_params=self.index_params,
+                        # --- END: CRITICAL FIX ---
+                    )
+                    return instance
+                    # --- END: CRITICAL FIX ---
+
+                self.logger.info(f"Attempt {attempt + 1}/{retries}: Creating Milvus instance for collection '{collection_name}' in a separate thread...")
+                vdb = await asyncio.to_thread(_create_instance_in_thread)
+                self.logger.info("Successfully created Milvus instance.")
+                return vdb
+
             except Exception as e:
+                last_exception = e
                 self.logger.warning(
-                    f"Attempt {attempt + 1}/{retries + 1} to connect to Milvus for collection '{collection_name}' failed: {e}"
+                    f"Attempt {attempt + 1}/{retries} failed to create Milvus instance: {e}. Retrying in {2 ** attempt}s..."
                 )
-                if attempt < retries:
-                    self.logger.info("Retrying in 3 seconds...")
-                    await asyncio.sleep(3)  # Use async sleep
-                else:
-                    self.logger.error(f"Failed to connect to Milvus for collection '{collection_name}' after {retries + 1} attempts.")
-                    raise RuntimeError(f"Could not connect to Milvus after {retries + 1} attempts.") from e
+                await asyncio.sleep(2 ** attempt)
+
+        raise RuntimeError(
+            f"Failed to create Milvus instance after {retries} retries."
+        ) from last_exception

     def _create_milvus_instance_with_retry(self, collection_name: str, embeddings: Embeddings, index_params: dict, connection_args: Optional[dict] = None) -> Milvus:
         """
crewplus-0.2.71.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: crewplus
-Version: 0.2.68
+Version: 0.2.71
 Summary: Base services for CrewPlus AI applications
 Author-Email: Tim Liu <tim@opsmateai.com>
 License: MIT
crewplus-0.2.71.dist-info/RECORD

@@ -1,26 +1,25 @@
-crewplus-0.2.68.dist-info/METADATA,sha256=sqbjQMZRzh7U0AajzZZ1lvgXuZzRsYDI3_pC9AKYLYg,5424
-crewplus-0.2.68.dist-info/WHEEL,sha256=9P2ygRxDrTJz3gsagc0Z96ukrxjr-LFBGOgv3AuKlCA,90
-crewplus-0.2.68.dist-info/entry_points.txt,sha256=6OYgBcLyFCUgeqLgnvMyOJxPCWzgy7se4rLPKtNonMs,34
-crewplus-0.2.68.dist-info/licenses/LICENSE,sha256=2_NHSHRTKB_cTcT_GXgcenOCtIZku8j343mOgAguTfc,1087
+crewplus-0.2.71.dist-info/METADATA,sha256=x45o-KC-K6zW8OIsL3eEwcS-mQ2wupMQYHVo5JANcLs,5424
+crewplus-0.2.71.dist-info/WHEEL,sha256=9P2ygRxDrTJz3gsagc0Z96ukrxjr-LFBGOgv3AuKlCA,90
+crewplus-0.2.71.dist-info/entry_points.txt,sha256=6OYgBcLyFCUgeqLgnvMyOJxPCWzgy7se4rLPKtNonMs,34
+crewplus-0.2.71.dist-info/licenses/LICENSE,sha256=2_NHSHRTKB_cTcT_GXgcenOCtIZku8j343mOgAguTfc,1087
 crewplus/__init__.py,sha256=m46HkZL1Y4toD619NL47Sn2Qe084WFFSFD7e6VoYKZc,284
 crewplus/callbacks/__init__.py,sha256=YG7ieeb91qEjp1zF0-inEN7mjZ7yT_D2yzdWFT8Z1Ws,63
 crewplus/callbacks/async_langfuse_handler.py,sha256=A4uFeLpvOUdc58M7sZoE65_C1V98u0QCvx5jUquM0pM,7006
-crewplus/callbacks/run_id_handler.py,sha256=T_YqmGmUp2_DIc01dlVY-aC43NKcevfRejAWN-lft6M,5266
 crewplus/services/__init__.py,sha256=V1CG8b2NOmRzNgQH7BPl4KVxWSYJH5vfEsW1wVErKNE,375
 crewplus/services/azure_chat_model.py,sha256=iWzJ2GQFSNmwJx-2O5_xKPSB6VVc-7T6bcfFI8_WezA,5521
 crewplus/services/gemini_chat_model.py,sha256=DYqz01H2TIHiCDQesSozVfOsMigno6QGwOtIweg7UHk,40103
 crewplus/services/init_services.py,sha256=tc1ti8Yufo2ixlJpwg8uH0KmoyQ4EqxCOe4uTEWnlRM,2413
 crewplus/services/model_load_balancer.py,sha256=Q9Gx3GrbKworU-Ytxeqp0ggHSgZ1Q6brtTk-nCl4sak,12095
-crewplus/services/tracing_manager.py,sha256=oyKoPAfirAvb9M6Kca6HMNQYo3mBtXCLM3jFkPYU55s,8368
+crewplus/services/tracing_manager.py,sha256=_C4zYj6o_k5mDWY7vM8UeaVaXp8SqrkvYOb0Jj-y3sY,7566
 crewplus/utils/__init__.py,sha256=2Gk1n5srFJQnFfBuYTxktdtKOVZyNrFcNaZKhXk35Pw,142
 crewplus/utils/schema_action.py,sha256=GDaBoVFQD1rXqrLVSMTfXYW1xcUu7eDcHsn57XBSnIg,422
 crewplus/utils/schema_document_updater.py,sha256=frvffxn2vbi71fHFPoGb9hq7gH2azmmdq17p-Fumnvg,7322
 crewplus/vectorstores/milvus/__init__.py,sha256=OeYv2rdyG7tcREIjBJPyt2TbE54NvyeRoWMe7LwopRE,245
 crewplus/vectorstores/milvus/milvus_schema_manager.py,sha256=-QRav-hzu-XWeJ_yKUMolal_EyMUspSg-nvh5sqlrlQ,11442
 crewplus/vectorstores/milvus/schema_milvus.py,sha256=wwNpfqsKS0xeozZES40IvB0iNwUtpCall_7Hkg0dL1g,27223
-crewplus/vectorstores/milvus/vdb_service.py,sha256=Jo2GWzdEuDSPlADCWA7wTgDbYQ65QiYWgFGbtkG9vG8,37789
+crewplus/vectorstores/milvus/vdb_service.py,sha256=U8I6IUYZK0gCe1R9rTnVezvZfEcUS9UEbKEoeJPX8kY,37528
 docs/GeminiChatModel.md,sha256=zZYyl6RmjZTUsKxxMiC9O4yV70MC4TD-IGUmWhIDBKA,8677
 docs/ModelLoadBalancer.md,sha256=aGHES1dcXPz4c7Y8kB5-vsCNJjriH2SWmjBkSGoYKiI,4398
 docs/VDBService.md,sha256=Dw286Rrf_fsi13jyD3Bo4Sy7nZ_G7tYm7d8MZ2j9hxk,9375
 docs/index.md,sha256=3tlc15uR8lzFNM5WjdoZLw0Y9o1P1gwgbEnOdIBspqc,1643
-crewplus-0.2.68.dist-info/RECORD,,
+crewplus-0.2.71.dist-info/RECORD,,
crewplus/callbacks/run_id_handler.py (removed)

@@ -1,131 +0,0 @@
-# File: crewplus/callbacks/run_id_handler.py
-
-from typing import Any, Dict, List
-from uuid import UUID
-import logging
-
-# Langfuse imports with graceful fallback
-from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler
-from .async_langfuse_handler import AsyncLangfuseCallbackHandler
-from langfuse import get_client
-from langchain_core.messages import BaseMessage
-#from langchain_core.outputs import LLMResult
-LANGFUSE_AVAILABLE = True
-
-# --- Custom Callback Handlers to capture the run_id ---
-
-class RunIdCallbackHandler(LangfuseCallbackHandler):
-    """
-    A custom handler that injects the LangChain run_id into the metadata
-    before the Langfuse observation is created.
-    """
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.run_trace_map = {}
-        # Use a named logger for better context
-        self.logger = logging.getLogger(__name__)
-        self.logger.info("RunIdCallbackHandler initialized.")
-
-    def _inject_run_id_to_metadata(self, run_id: UUID, kwargs: Dict[str, Any]) -> Dict[str, Any]:
-        """Helper to safely add the run_id to the metadata in kwargs."""
-        metadata = kwargs.get("metadata") or {}
-        metadata["langchain_run_id"] = str(run_id)
-        kwargs["metadata"] = metadata
-        return kwargs
-
-    def on_chain_start(
-        self,
-        serialized: Dict[str, Any],
-        inputs: Dict[str, Any],
-        *,
-        run_id: UUID,
-        **kwargs: Any,
-    ) -> Any:
-        self.logger.debug(f"[on_chain_start] Intercepted run_id: {run_id}")
-        kwargs = self._inject_run_id_to_metadata(run_id, kwargs)
-
-        # Call the base handler with the modified kwargs
-        result = super().on_chain_start(serialized, inputs, run_id=run_id, **kwargs)
-
-        # We still map the trace_id for easy retrieval in tests
-        if self.last_trace_id:
-            self.run_trace_map[str(run_id)] = self.last_trace_id
-            self.logger.info(f"[on_chain_start] Mapped run_id '{run_id}' to trace_id '{self.last_trace_id}'.")
-
-        return result
-
-    def on_chat_model_start(
-        self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
-        *,
-        run_id: UUID,
-        parent_run_id: UUID | None = None,
-        **kwargs: Any,
-    ) -> Any:
-        self.logger.debug(f"[on_chat_model_start] Intercepted run_id: {run_id}")
-
-        # Only inject the run_id if this is the root of the trace
-        if parent_run_id is None:
-            kwargs = self._inject_run_id_to_metadata(run_id, kwargs)
-
-        # Call the base handler with potentially modified kwargs
-        result = super().on_chat_model_start(serialized, messages, run_id=run_id, parent_run_id=parent_run_id, **kwargs)
-
-        if parent_run_id is None and self.last_trace_id:
-            self.run_trace_map[str(run_id)] = self.last_trace_id
-            self.logger.info(f"[on_chat_model_start] Mapped root run_id '{run_id}' to trace_id '{self.last_trace_id}'.")
-
-        return result
-
-# You would similarly update the AsyncRunIdCallbackHandler if you use it
-class AsyncRunIdCallbackHandler(AsyncLangfuseCallbackHandler):
-    """
-    An async custom handler that injects the LangChain run_id into the metadata
-    before the Langfuse observation is created.
-    """
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.run_trace_map = {}
-        self.logger = logging.getLogger(__name__)
-        self.logger.info("AsyncRunIdCallbackHandler initialized.")
-
-    def _inject_run_id_to_metadata(self, run_id: UUID, kwargs: Dict[str, Any]) -> Dict[str, Any]:
-        """Helper to safely add the run_id to the metadata in kwargs."""
-        metadata = kwargs.get("metadata") or {}
-        metadata["langchain_run_id"] = str(run_id)
-        kwargs["metadata"] = metadata
-        return kwargs
-
-    async def on_chain_start(
-        self,
-        serialized: Dict[str, Any],
-        inputs: Dict[str, Any],
-        *,
-        run_id: UUID,
-        **kwargs: Any,
-    ) -> Any:
-        self.logger.debug(f"Async [on_chain_start] Intercepted run_id: {run_id}")
-        kwargs = self._inject_run_id_to_metadata(run_id, kwargs)
-        result = await super().on_chain_start(serialized, inputs, run_id=run_id, **kwargs)
-        if self.last_trace_id:
-            self.run_trace_map[str(run_id)] = self.last_trace_id
-        return result
-
-    async def on_chat_model_start(
-        self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
-        *,
-        run_id: UUID,
-        parent_run_id: UUID | None = None,
-        **kwargs: Any,
-    ) -> Any:
-        self.logger.debug(f"Async [on_chat_model_start] Intercepted run_id: {run_id}")
-        if parent_run_id is None:
-            kwargs = self._inject_run_id_to_metadata(run_id, kwargs)
-        result = await super().on_chat_model_start(serialized, messages, run_id=run_id, parent_run_id=parent_run_id, **kwargs)
-
-        if parent_run_id is None and self.last_trace_id:
-            self.run_trace_map[str(run_id)] = self.last_trace_id
-        return result