crewplus-0.2.89-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- crewplus/__init__.py +10 -0
- crewplus/callbacks/__init__.py +1 -0
- crewplus/callbacks/async_langfuse_handler.py +166 -0
- crewplus/services/__init__.py +21 -0
- crewplus/services/azure_chat_model.py +145 -0
- crewplus/services/feedback.md +55 -0
- crewplus/services/feedback_manager.py +267 -0
- crewplus/services/gemini_chat_model.py +884 -0
- crewplus/services/init_services.py +57 -0
- crewplus/services/model_load_balancer.py +264 -0
- crewplus/services/schemas/feedback.py +61 -0
- crewplus/services/tracing_manager.py +182 -0
- crewplus/utils/__init__.py +4 -0
- crewplus/utils/schema_action.py +7 -0
- crewplus/utils/schema_document_updater.py +173 -0
- crewplus/utils/tracing_util.py +55 -0
- crewplus/vectorstores/milvus/__init__.py +5 -0
- crewplus/vectorstores/milvus/milvus_schema_manager.py +270 -0
- crewplus/vectorstores/milvus/schema_milvus.py +586 -0
- crewplus/vectorstores/milvus/vdb_service.py +917 -0
- crewplus-0.2.89.dist-info/METADATA +144 -0
- crewplus-0.2.89.dist-info/RECORD +29 -0
- crewplus-0.2.89.dist-info/WHEEL +4 -0
- crewplus-0.2.89.dist-info/entry_points.txt +4 -0
- crewplus-0.2.89.dist-info/licenses/LICENSE +21 -0
- docs/GeminiChatModel.md +247 -0
- docs/ModelLoadBalancer.md +134 -0
- docs/VDBService.md +238 -0
- docs/index.md +23 -0
crewplus/services/init_services.py

```diff
@@ -0,0 +1,57 @@
+import logging
+import os
+from typing import Optional
+
+from .model_load_balancer import ModelLoadBalancer
+
+model_balancer = None
+
+def init_load_balancer(
+    config_path: Optional[str] = None,
+    logger: Optional[logging.Logger] = None
+):
+    """
+    Initializes the global ModelLoadBalancer instance.
+
+    This function is idempotent. If the balancer is already initialized,
+    it does nothing. It follows a safe initialization pattern where the
+    global instance is only assigned after successful configuration loading.
+
+    Args:
+        config_path (Optional[str]): The path to the model configuration file.
+            If not provided, it's determined by the `MODEL_CONFIG_PATH`
+            environment variable, or defaults to "config/models_config.json".
+        logger (Optional[logging.Logger]): An optional logger instance to be
+            used by the model balancer.
+    """
+    global model_balancer
+    if model_balancer is None:
+        # Use parameter if provided, otherwise check env var, then default
+        current_dir = os.path.dirname(os.path.abspath(__file__))
+        base_package_dir = os.path.dirname(os.path.dirname(current_dir))
+        default_config_path = os.path.join(base_package_dir, "_config", "models_config.json")
+
+        final_config_path = config_path or os.getenv(
+            "MODEL_CONFIG_PATH",
+            default_config_path
+        )
+        try:
+            # 1. Create a local instance first.
+            balancer = ModelLoadBalancer(
+                config_path=final_config_path,
+                logger=logger
+            )
+            # 2. Attempt to load its configuration.
+            balancer.load_config()
+            # 3. Only assign to the global variable on full success.
+            model_balancer = balancer
+        except Exception as e:
+            # If any step fails, the global model_balancer remains None,
+            # allowing for another initialization attempt later.
+            # Re-raise the exception to notify the caller of the failure.
+            raise RuntimeError(f"Failed to initialize and configure ModelLoadBalancer from {final_config_path}: {e}") from e
+
+def get_model_balancer() -> ModelLoadBalancer:
+    if model_balancer is None:
+        raise RuntimeError("ModelLoadBalancer not initialized. Please call init_load_balancer() first.")
+    return model_balancer
```
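For orientation only (this note and the snippet below are not part of the package diff): a minimal usage sketch of the two functions above, assuming the hunk corresponds to `crewplus/services/init_services.py` (it matches the +57 line count in the file list) and that a valid models config exists at the hypothetical path shown.

```python
# Sketch only: initialize the shared balancer once at startup, then fetch it anywhere.
from crewplus.services.init_services import init_load_balancer, get_model_balancer

init_load_balancer(config_path="config/models_config.json")  # hypothetical path; call is idempotent
balancer = get_model_balancer()
llm = balancer.get_model(provider="azure-openai", model_type="inference")
```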
crewplus/services/model_load_balancer.py

```diff
@@ -0,0 +1,264 @@
+import json
+import random
+import logging
+import threading
+from typing import Dict, List, Optional, Union
+from collections import defaultdict
+from langchain_openai import ChatOpenAI, AzureOpenAIEmbeddings
+from .gemini_chat_model import GeminiChatModel
+from .azure_chat_model import TracedAzureChatOpenAI
+
+
+class ModelLoadBalancer:
+    def __init__(self,
+                 config_path: Optional[str] = "config/models_config.json",
+                 config_data: Optional[Dict] = None,
+                 logger: Optional[logging.Logger] = None):
+        """
+        Initializes the ModelLoadBalancer.
+
+        Args:
+            config_path: Path to the JSON configuration file.
+            config_data: A dictionary containing the model configuration.
+            logger: An optional logger instance. If not provided, a default one is created.
+
+        Raises:
+            ValueError: If neither config_path nor config_data is provided.
+        """
+        if not config_path and not config_data:
+            raise ValueError("Either 'config_path' or 'config_data' must be provided.")
+
+        self.config_path = config_path
+        self.config_data = config_data
+        self.logger = logger or logging.getLogger(__name__)
+        self.models_config: List[Dict] = []
+        self.thread_local = threading.local()
+        self._initialize_state()
+        self._config_loaded = False  # Flag to check if config is loaded
+
+    def load_config(self):
+        """Load and validate model configurations from a file path or a dictionary."""
+        self.logger.debug("Model balancer: loading configuration.")
+        try:
+            config = None
+            if self.config_data:
+                config = self.config_data
+            elif self.config_path:
+                with open(self.config_path, 'r') as f:
+                    config = json.load(f)
+            else:
+                # This case is handled in __init__, but as a safeguard:
+                raise RuntimeError("No configuration source provided (path or data).")
+
+            # Validate config
+            if 'models' not in config or not isinstance(config['models'], list):
+                raise ValueError("Configuration must contain a 'models' list.")
+
+            for model in config.get('models', []):
+                if 'provider' not in model or 'type' not in model or 'id' not in model:
+                    self.logger.error("Model config must contain 'id', 'provider', and 'type' fields.")
+                    raise ValueError("Model config must contain 'id', 'provider', and 'type' fields.")
+
+            self.models_config = config['models']
+
+            self._config_loaded = True
+            self.logger.debug("Model balancer: configuration loaded successfully.")
+        except (FileNotFoundError, json.JSONDecodeError, ValueError) as e:
+            self._config_loaded = False
+            self.logger.error(f"Failed to load model configuration: {e}", exc_info=True)
+            raise RuntimeError(f"Failed to load model configuration: {e}")
+
+    def get_model(self, provider: str = None, model_type: str = None, deployment_name: str = None, with_metadata: bool = False, selection_strategy: str = 'random', disable_streaming: bool = False):
+        """
+        Get a model instance.
+
+        Can fetch a model in two ways:
+        1. By its specific `deployment_name`.
+        2. By `provider` and `model_type`, which will select a model using a specified strategy.
+
+        Args:
+            provider: The model provider (e.g., 'azure-openai', 'google-genai').
+            model_type: The type of model (e.g., 'inference', 'embedding', 'embedding-large').
+            deployment_name: The unique name for the model deployment.
+            with_metadata: If True, returns a tuple of (model, deployment_name).
+            selection_strategy: The selection strategy ('random', 'round_robin', or 'least_used'). Defaults to 'random'.
+            disable_streaming: If True, get a model instance with streaming disabled.
+
+        Returns:
+            An instantiated language model object, or a tuple if with_metadata is True.
+
+        Raises:
+            RuntimeError: If the model configuration has not been loaded.
+            ValueError: If the requested model cannot be found or if parameters are insufficient.
+        """
+        if not self._config_loaded:
+            self.logger.error("Model configuration not loaded")
+            raise RuntimeError("Model configuration not loaded")
+
+        if deployment_name:
+            for model_config in self.models_config:
+                if model_config.get('deployment_name') == deployment_name:
+                    model = self._get_or_create_model(model_config, disable_streaming)
+                    if with_metadata:
+                        return model, deployment_name
+                    return model
+
+            self.logger.error(f"No model found for deployment name: {deployment_name}")
+            raise ValueError(f"No model found for deployment name: {deployment_name}")
+
+        if provider and model_type:
+            candidates = [model for model in self.models_config if model.get('provider') == provider and model.get('type') == model_type]
+            if not candidates:
+                self.logger.error(f"No models found for provider '{provider}' and type '{model_type}'")
+                raise ValueError(f"No models found for provider '{provider}' and type '{model_type}'")
+
+            if selection_strategy == 'random':
+                selected_model_config = self._random_selection(candidates)
+            elif selection_strategy == 'round_robin':
+                selected_model_config = self._round_robin_selection(candidates)
+            elif selection_strategy == 'least_used':
+                selected_model_config = self._least_used_selection(candidates)
+            else:
+                self.logger.warning(f"Unsupported selection strategy: '{selection_strategy}'. Defaulting to 'random'.")
+                selected_model_config = self._random_selection(candidates)
+
+            model = self._get_or_create_model(selected_model_config, disable_streaming)
+            if with_metadata:
+                return model, selected_model_config.get('deployment_name')
+            return model
+
+        raise ValueError("Either 'deployment_name' or both 'provider' and 'model_type' must be provided.")
+
+    def _get_thread_local_models_cache(self) -> Dict:
+        """Gets the model cache for the current thread, creating it if it doesn't exist."""
+        if not hasattr(self.thread_local, 'models_cache'):
+            self.thread_local.models_cache = {}
+        return self.thread_local.models_cache
+
+    def _get_or_create_model(self, model_config: Dict, disable_streaming: bool = False):
+        """
+        Gets a model instance from the thread-local cache. If it doesn't exist,
+        it instantiates, caches, and returns it.
+        """
+        model_id = model_config['id']
+        cache_key = f"{model_id}"
+        if disable_streaming:
+            cache_key += "-non-streaming"
+
+        models_cache = self._get_thread_local_models_cache()
+
+        if cache_key not in models_cache:
+            self.logger.debug(f"Creating new model instance for id {cache_key} in thread {threading.get_ident()}")
+            models_cache[cache_key] = self._instantiate_model(model_config, disable_streaming)
+
+        return models_cache[cache_key]
+
+    def _instantiate_model(self, model_config: Dict, disable_streaming: bool = False):
+        """Instantiate and return an LLM object based on the model configuration"""
+        provider = model_config['provider']
+        self.logger.debug(f"Model balancer: instantiating {provider} -- {model_config.get('deployment_name')}")
+
+        if provider == 'azure-openai':
+            kwargs = {
+                'azure_deployment': model_config['deployment_name'],
+                'openai_api_version': model_config['api_version'],
+                'azure_endpoint': model_config['api_base'],
+                'openai_api_key': model_config['api_key']
+            }
+            if 'temperature' in model_config:
+                kwargs['temperature'] = model_config['temperature']
+
+            # The 'disable_streaming' parameter takes precedence
+            if disable_streaming:
+                kwargs['disable_streaming'] = True
+            elif model_config.get('deployment_name') == 'o1-mini':
+                kwargs['disable_streaming'] = True
+
+            return TracedAzureChatOpenAI(**kwargs)
+        elif provider == 'openai':
+            kwargs = {
+                'openai_api_key': model_config['api_key']
+            }
+            if 'temperature' in model_config:
+                kwargs['temperature'] = model_config['temperature']
+            return ChatOpenAI(**kwargs)
+        elif provider == 'azure-openai-embeddings':
+            try:
+                emb_kwargs = dict(
+                    model=model_config['model_name'],
+                    azure_deployment=model_config['deployment_name'],
+                    openai_api_version=model_config['api_version'],
+                    api_key=model_config['api_key'],
+                    azure_endpoint=model_config['api_base'],
+                    chunk_size=16, request_timeout=60, max_retries=2
+                )
+                if 'dimensions' in model_config:
+                    emb_kwargs['dimensions'] = model_config['dimensions']
+                return AzureOpenAIEmbeddings(**emb_kwargs)
+            except Exception as e:
+                self.logger.error(f"Failed to instantiate AzureOpenAIEmbeddings: {e}")
+                return None
+        elif provider == 'google-genai':
+            kwargs = {
+                'google_api_key': model_config['api_key'],
+                'model_name': model_config['deployment_name']  # Map deployment_name to model_name
+            }
+            if 'temperature' in model_config:
+                kwargs['temperature'] = model_config['temperature']
+            if 'max_tokens' in model_config:
+                kwargs['max_tokens'] = model_config['max_tokens']
+            if disable_streaming:
+                kwargs['disable_streaming'] = True
+            return GeminiChatModel(**kwargs)
+        elif provider == 'vertex-ai':
+            deployment_name = model_config['deployment_name']
+
+            # Handle the 'model_name@location' format for deployment_name
+            model_name_for_gemini = deployment_name.split('@')[0] if '@' in deployment_name else deployment_name
+
+            kwargs = {
+                'use_vertex_ai': True,
+                'model_name': model_name_for_gemini,
+                'project_id': model_config['project_id'],
+                'location': model_config['location'],
+            }
+            if 'service_account_file' in model_config:
+                kwargs['service_account_file'] = model_config['service_account_file']
+            if 'temperature' in model_config:
+                kwargs['temperature'] = model_config['temperature']
+            if 'max_tokens' in model_config:
+                kwargs['max_tokens'] = model_config['max_tokens']
+            if disable_streaming:
+                kwargs['disable_streaming'] = True
+            return GeminiChatModel(**kwargs)
+        else:
+            self.logger.error(f"Unsupported provider: {provider}")
+            raise ValueError(f"Unsupported provider: {provider}")
+
+    def _initialize_state(self):
+        self.active_models = []
+        self.usage_counter = defaultdict(int)
+        self.current_indices = {}
+
+    def _random_selection(self, candidates: list) -> Dict:
+        """Selects a model randomly from a list of candidates."""
+        model = random.choice(candidates)
+        self.usage_counter[model['id']] += 1
+        return model
+
+    def _round_robin_selection(self, candidates: list) -> Dict:
+        if id(candidates) not in self.current_indices:
+            self.current_indices[id(candidates)] = 0
+        idx = self.current_indices[id(candidates)]
+        model = candidates[idx]
+        self.current_indices[id(candidates)] = (idx + 1) % len(candidates)
+        self.usage_counter[model['id']] += 1
+
+        return model
+
+    def _least_used_selection(self, candidates: list) -> Dict:
+        min_usage = min(self.usage_counter[m['id']] for m in candidates)
+        least_used = [m for m in candidates if self.usage_counter[m['id']] == min_usage]
+        model = random.choice(least_used)
+        self.usage_counter[model['id']] += 1
+        return model
```
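As an illustration (not part of the diff), here is one way the class above could be driven from an in-memory config dict rather than a JSON file; every field value below is a made-up placeholder.

```python
# Sketch only: build a balancer from config_data and pick a deployment.
config_data = {
    "models": [
        {
            "id": "gpt4o-east",                      # required by load_config
            "provider": "azure-openai",              # required
            "type": "inference",                     # required
            "deployment_name": "gpt-4o",
            "api_version": "2024-02-01",
            "api_base": "https://example.openai.azure.com",
            "api_key": "***",
            "temperature": 0.2,
        },
    ]
}

balancer = ModelLoadBalancer(config_data=config_data)
balancer.load_config()

# Round-robin over matching deployments and report which one was picked.
llm, picked = balancer.get_model(
    provider="azure-openai",
    model_type="inference",
    selection_strategy="round_robin",
    with_metadata=True,
)
```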
crewplus/services/schemas/feedback.py

```diff
@@ -0,0 +1,61 @@
+from typing import Optional, Union, Dict, Any
+from uuid import UUID
+from pydantic import BaseModel, Field, ConfigDict
+import datetime
+
+class FeedbackConfig(BaseModel):
+    api_url: Optional[str] = None
+    api_key: Optional[str] = None
+
+class FeedbackIn(BaseModel):
+    run_id: Union[UUID, str]
+    key: str
+    score: Optional[Union[float, int, bool]] = None
+    value: Optional[Union[float, int, bool, str, Dict]] = None
+    correction: Optional[Dict] = None
+    comment: Optional[str] = None
+    feedback_group_id: Optional[Union[UUID, str]] = None
+    config: Optional[FeedbackConfig] = None
+
+    model_config = ConfigDict(
+        json_schema_extra={
+            "example": {
+                "run_id": "5323d917-f337-42c2-8437-22e6fa623930",
+                "key": "accuracy",
+                "score": 0.95,
+                "value": "yes",
+                "correction": {"expected": "correct answer"},
+                "comment": "The model performed well on this task.",
+                "feedback_group_id": "123e4567-e89b-12d3-a456-426614174000",
+                "config": {
+                    "api_url": "https://api.smith.langchain.com",
+                    "api_key": "your_api_key_here"
+                }
+            }
+        }
+    )
+
+class FeedbackUpdate(BaseModel):
+    score: Optional[Union[float, int, bool]] = None
+    value: Optional[Union[float, int, bool, str, Dict]] = None
+    correction: Optional[Dict] = None
+    comment: Optional[str] = None
+    config: Optional[FeedbackConfig] = None
+
+
+class FeedbackOut(BaseModel):
+    """
+    Represents a retrieved feedback item, including its unique ID and group context.
+    This corresponds to a Langfuse Score object enriched with its parent's metadata.
+    """
+    feedback_id: str = Field(..., description="The unique identifier for the feedback item (maps to Langfuse score ID).")
+    run_id: str = Field(..., description="The ID of the run (trace) this feedback belongs to.")
+    key: str = Field(..., description="The name of the feedback key.")
+    score: Optional[Union[float, int, bool, str]] = Field(None, description="The numerical or string score of the feedback.")
+    value: Optional[Any] = Field(None, description="The original value of the feedback, which could be a dict or other type.")
+    comment: Optional[str] = Field(None, description="The feedback comment.")
+    correction: Optional[Dict] = Field(None, description="Correction data from the parent observation.")
+    feedback_group_id: Optional[str] = Field(None, description="The ID of the group this feedback belongs to.")
+    created_at: datetime.datetime = Field(..., description="The timestamp when the feedback was created.")
+
+    model_config = ConfigDict(from_attributes=True)
```
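A small usage sketch (not part of the diff) for the request schema above; the run_id is an arbitrary example value and the Pydantic v2 serialization call is inferred from the use of ConfigDict.

```python
# Sketch only: build a feedback payload and serialize it, dropping unset fields.
feedback = FeedbackIn(
    run_id="5323d917-f337-42c2-8437-22e6fa623930",
    key="accuracy",
    score=0.95,
    comment="Answer matched the reference.",
)
print(feedback.model_dump_json(exclude_none=True))
```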
crewplus/services/tracing_manager.py

```diff
@@ -0,0 +1,182 @@
+# File: crewplus/services/tracing_manager.py
+
+from typing import Any, Optional, List, Protocol, Dict
+import os
+import logging
+
+# Langfuse imports with graceful fallback. This allows the application to run
+# even if the langfuse library is not installed.
+try:
+    from langfuse.langchain import CallbackHandler as LangfuseCallbackHandler
+    from ..callbacks.async_langfuse_handler import AsyncLangfuseCallbackHandler
+    from ..utils.tracing_util import get_langfuse_handler, get_async_langfuse_handler
+    LANGFUSE_AVAILABLE = True
+except ImportError:
+    LANGFUSE_AVAILABLE = False
+    LangfuseCallbackHandler = None
+    AsyncLangfuseCallbackHandler = None
+    get_langfuse_handler = None
+    get_async_langfuse_handler = None
+
+class TracingContext(Protocol):
+    """
+    A protocol that defines a formal contract for a model to be "traceable."
+
+    This protocol ensures that any class using the TracingManager provides the
+    necessary attributes and methods for the manager to function correctly. By
+    using a Protocol, we leverage Python's static analysis tools (like mypy)
+    to enforce this contract, preventing runtime errors and making the system
+    more robust and self-documenting.
+
+    It allows the TracingManager to be completely decoupled from any specific
+    model implementation, promoting clean, compositional design.
+
+    A class that implements this protocol must provide:
+    - A `logger` attribute for logging.
+    - An `enable_tracing` attribute to control tracing.
+    - A `get_model_identifier` method to describe itself for logging purposes.
+    """
+    logger: logging.Logger
+    enable_tracing: Optional[bool]
+
+    def get_model_identifier(self) -> str:
+        """
+        Return a string that uniquely identifies the model instance for logging.
+
+        Example:
+            "GeminiChatModel (model='gemini-1.5-flash')"
+
+        Note:
+            The '...' (Ellipsis) is the standard way in a Protocol to indicate
+            that this method must be implemented by any class that conforms to
+            this protocol, but has no implementation in the protocol itself.
+        """
+        ...
+
+class TracingManager:
+    """
+    Manages the initialization and injection of tracing handlers for chat models.
+
+    This class uses a composition-based approach, taking a context object that
+    fulfills the TracingContext protocol. This design is highly extensible,
+    allowing new tracing providers (e.g., Helicone, OpenTelemetry) to be added
+    with minimal, isolated changes.
+    """
+
+    def __init__(self, context: TracingContext):
+        """
+        Args:
+            context: An object (typically a chat model instance) that conforms
+                to the TracingContext protocol.
+        """
+        self.context = context
+        self._sync_handlers: List[Any] = []
+        self._async_handlers: List[Any] = []
+        self._initialize_handlers()
+
+    def _initialize_handlers(self):
+        """
+        Initializes all supported tracing handlers. This is the central point
+        for adding new observability tools.
+        """
+        self._sync_handlers = []
+        self._async_handlers = []
+        self._initialize_langfuse()
+        # To add a new handler (e.g., Helicone), you would add a call to
+        # self._initialize_helicone() here.
+
+    def _initialize_langfuse(self):
+        """Initializes the Langfuse handler if it's available and enabled."""
+        self.context.logger.debug("Attempting to initialize Langfuse handlers.")
+        if not LANGFUSE_AVAILABLE:
+            if self.context.enable_tracing is True:
+                self.context.logger.warning("Langfuse is not installed; tracing will be disabled. Install with: pip install langfuse")
+            else:
+                self.context.logger.debug("Langfuse is not installed, skipping handler initialization.")
+            return
+
+        # Determine if Langfuse should be enabled via an explicit flag or
+        # by detecting its environment variables.
+        enable_langfuse = self.context.enable_tracing
+        if enable_langfuse is None:  # Auto-detect if not explicitly set
+            langfuse_env_vars = ["LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY"]
+            enable_langfuse = any(os.getenv(var) for var in langfuse_env_vars)
+
+        if enable_langfuse:
+            try:
+                # Create both sync and async handlers. We'll pick one at runtime.
+                sync_handler = get_langfuse_handler()
+                self._sync_handlers.append(sync_handler)
+
+                if AsyncLangfuseCallbackHandler:
+                    async_handler = get_async_langfuse_handler()
+                    self._async_handlers.append(async_handler)
+
+                self.context.logger.info(f"Langfuse tracing enabled for {self.context.get_model_identifier()}")
+            except Exception as e:
+                self.context.logger.warning(f"Failed to initialize Langfuse: {e}", exc_info=True)
+        else:
+            self.context.logger.info("Langfuse is not enabled, skipping handler initialization.")
+
+    def add_callbacks_to_config(self, config: Optional[dict], handlers: List[Any]) -> dict:
+        """A generic helper to add a list of handlers to a config object."""
+        if config is None:
+            config = {}
+
+        self.context.logger.debug(f"Adding callbacks to config. Have {len(handlers)} handlers to add.")
+
+        if not handlers or config.get("metadata", {}).get("tracing_disabled"):
+            self.context.logger.debug("No handlers to add or tracing is disabled for this run.")
+            return config
+
+        callbacks = config.get("callbacks")
+
+        if hasattr(callbacks, 'add_handler') and hasattr(callbacks, 'handlers'):
+            # This block is for CallbackManager instances
+            self.context.logger.debug(f"Config has a CallbackManager with {len(callbacks.handlers)} existing handlers.")
+            for handler in handlers:
+                if not any(isinstance(cb, type(handler)) for cb in callbacks.handlers):
+                    callbacks.add_handler(handler, inherit=True)
+            self.context.logger.debug(f"CallbackManager now has {len(callbacks.handlers)} handlers.")
+            return config
+
+        # This block is for simple lists of callbacks
+        current_callbacks = callbacks or []
+        self.context.logger.debug(f"Config has a list with {len(current_callbacks)} existing callbacks.")
+        new_callbacks = list(current_callbacks)
+
+        for handler in handlers:
+            if not any(isinstance(cb, type(handler)) for cb in new_callbacks):
+                new_callbacks.append(handler)
+
+        if len(new_callbacks) > len(current_callbacks):
+            # Create a new dictionary with the updated callbacks list.
+            # This is a safe operation that overwrites the existing 'callbacks'
+            # key and avoids mutating the original config object.
+            return {**config, "callbacks": new_callbacks}
+
+        return config
+
+    def add_sync_callbacks_to_config(self, config: Optional[dict], handlers: Optional[List[Any]] = None) -> dict:
+        """
+        Adds synchronous tracing handlers to the request configuration.
+
+        Args:
+            config: The configuration dictionary to which callbacks will be added.
+            handlers: An optional list of handlers to add. If not provided,
+                the manager's default synchronous handlers are used.
+        """
+        handlers_to_add = self._sync_handlers if handlers is None else handlers
+        return self.add_callbacks_to_config(config, handlers_to_add)
+
+    def add_async_callbacks_to_config(self, config: Optional[dict], handlers: Optional[List[Any]] = None) -> dict:
+        """
+        Adds asynchronous tracing handlers to the request configuration.
+
+        Args:
+            config: The configuration dictionary to which callbacks will be added.
+            handlers: An optional list of handlers to add. If not provided,
+                the manager's default asynchronous handlers are used.
+        """
+        handlers_to_add = self._async_handlers if handlers is None else handlers
+        return self.add_callbacks_to_config(config, handlers_to_add)
```
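To show how the protocol and manager compose (sketch only, not part of the diff; the class name and values are hypothetical):

```python
import logging

# Sketch only: any object exposing `logger`, `enable_tracing`, and
# `get_model_identifier()` satisfies the TracingContext protocol.
class DemoModel:
    def __init__(self):
        self.logger = logging.getLogger("demo")
        self.enable_tracing = None  # None lets the manager auto-detect Langfuse env vars

    def get_model_identifier(self) -> str:
        return "DemoModel (model='demo-1')"

manager = TracingManager(DemoModel())
# Inject whatever sync handlers were initialized into a per-request config.
config = manager.add_sync_callbacks_to_config({"metadata": {"user_id": "u-123"}})
```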
crewplus/utils/schema_action.py

```diff
@@ -0,0 +1,7 @@
+from enum import Enum
+
+class Action(Enum):
+    UPSERT = "upsert"  # Update existing fields; if a match is found, it updates, otherwise, it inserts. Does not delete unmatched existing fields.
+    DELETE = "delete"  # Clear data from fields in the schema.
+    UPDATE = "update"  # Update only the matching original fields.
+    INSERT = "insert"  # Insert data, clearing the original fields before inserting new values.
```