ragaai-catalyst 2.1.4.1b0__py3-none-any.whl → 2.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragaai_catalyst/__init__.py +23 -2
- ragaai_catalyst/dataset.py +462 -1
- ragaai_catalyst/evaluation.py +76 -7
- ragaai_catalyst/ragaai_catalyst.py +52 -10
- ragaai_catalyst/redteaming/__init__.py +7 -0
- ragaai_catalyst/redteaming/config/detectors.toml +13 -0
- ragaai_catalyst/redteaming/data_generator/scenario_generator.py +95 -0
- ragaai_catalyst/redteaming/data_generator/test_case_generator.py +120 -0
- ragaai_catalyst/redteaming/evaluator.py +125 -0
- ragaai_catalyst/redteaming/llm_generator.py +136 -0
- ragaai_catalyst/redteaming/llm_generator_old.py +83 -0
- ragaai_catalyst/redteaming/red_teaming.py +331 -0
- ragaai_catalyst/redteaming/requirements.txt +4 -0
- ragaai_catalyst/redteaming/tests/grok.ipynb +97 -0
- ragaai_catalyst/redteaming/tests/stereotype.ipynb +2258 -0
- ragaai_catalyst/redteaming/upload_result.py +38 -0
- ragaai_catalyst/redteaming/utils/issue_description.py +114 -0
- ragaai_catalyst/redteaming/utils/rt.png +0 -0
- ragaai_catalyst/redteaming_old.py +171 -0
- ragaai_catalyst/synthetic_data_generation.py +400 -22
- ragaai_catalyst/tracers/__init__.py +17 -1
- ragaai_catalyst/tracers/agentic_tracing/data/data_structure.py +4 -2
- ragaai_catalyst/tracers/agentic_tracing/tracers/agent_tracer.py +212 -148
- ragaai_catalyst/tracers/agentic_tracing/tracers/base.py +657 -247
- ragaai_catalyst/tracers/agentic_tracing/tracers/custom_tracer.py +50 -19
- ragaai_catalyst/tracers/agentic_tracing/tracers/llm_tracer.py +588 -177
- ragaai_catalyst/tracers/agentic_tracing/tracers/main_tracer.py +99 -100
- ragaai_catalyst/tracers/agentic_tracing/tracers/network_tracer.py +3 -3
- ragaai_catalyst/tracers/agentic_tracing/tracers/tool_tracer.py +230 -29
- ragaai_catalyst/tracers/agentic_tracing/upload/trace_uploader.py +358 -0
- ragaai_catalyst/tracers/agentic_tracing/upload/upload_agentic_traces.py +75 -20
- ragaai_catalyst/tracers/agentic_tracing/upload/upload_code.py +55 -11
- ragaai_catalyst/tracers/agentic_tracing/upload/upload_local_metric.py +74 -0
- ragaai_catalyst/tracers/agentic_tracing/upload/upload_trace_metric.py +47 -16
- ragaai_catalyst/tracers/agentic_tracing/utils/create_dataset_schema.py +4 -2
- ragaai_catalyst/tracers/agentic_tracing/utils/file_name_tracker.py +26 -3
- ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +182 -17
- ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +1233 -497
- ragaai_catalyst/tracers/agentic_tracing/utils/span_attributes.py +81 -10
- ragaai_catalyst/tracers/agentic_tracing/utils/supported_llm_provider.toml +34 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/system_monitor.py +215 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py +0 -32
- ragaai_catalyst/tracers/agentic_tracing/utils/unique_decorator.py +3 -1
- ragaai_catalyst/tracers/agentic_tracing/utils/zip_list_of_unique_files.py +73 -47
- ragaai_catalyst/tracers/distributed.py +300 -0
- ragaai_catalyst/tracers/exporters/__init__.py +3 -1
- ragaai_catalyst/tracers/exporters/dynamic_trace_exporter.py +160 -0
- ragaai_catalyst/tracers/exporters/ragaai_trace_exporter.py +129 -0
- ragaai_catalyst/tracers/langchain_callback.py +809 -0
- ragaai_catalyst/tracers/llamaindex_instrumentation.py +424 -0
- ragaai_catalyst/tracers/tracer.py +301 -55
- ragaai_catalyst/tracers/upload_traces.py +24 -7
- ragaai_catalyst/tracers/utils/convert_langchain_callbacks_output.py +61 -0
- ragaai_catalyst/tracers/utils/convert_llama_instru_callback.py +69 -0
- ragaai_catalyst/tracers/utils/extraction_logic_llama_index.py +74 -0
- ragaai_catalyst/tracers/utils/langchain_tracer_extraction_logic.py +82 -0
- ragaai_catalyst/tracers/utils/model_prices_and_context_window_backup.json +9365 -0
- ragaai_catalyst/tracers/utils/trace_json_converter.py +269 -0
- {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/METADATA +367 -45
- ragaai_catalyst-2.1.5.dist-info/RECORD +97 -0
- {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/WHEEL +1 -1
- ragaai_catalyst-2.1.4.1b0.dist-info/RECORD +0 -67
- {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/LICENSE +0 -0
- {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/top_level.txt +0 -0
ragaai_catalyst/tracers/langchain_callback.py (new file)
@@ -0,0 +1,809 @@
from typing import Any, Dict, List, Optional, Union, Sequence

import attr
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, AgentAction, AgentFinish, BaseMessage
from datetime import datetime
import json
import os
from uuid import UUID
from functools import wraps
import asyncio
from langchain_core.documents import Document
import logging
import tempfile
import sys
import importlib
from importlib.util import find_spec

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class LangchainTracer(BaseCallbackHandler):
    """
    An enhanced callback handler for LangChain that traces all actions and saves them to a JSON file.
    Includes improved error handling, async support, and configuration options.
    """

    def __init__(
        self,
        output_path: str = tempfile.gettempdir(),
        trace_all: bool = True,
        save_interval: Optional[int] = None,
        log_level: int = logging.INFO,
    ):
        """
        Initialize the tracer with enhanced configuration options.

        Args:
            output_path (str): Directory where trace files will be saved
            trace_all (bool): Whether to trace all components or only specific ones
            save_interval (Optional[int]): Interval in seconds to auto-save traces
            log_level (int): Logging level for the tracer
        """
        super().__init__()
        self.output_path = output_path
        self.trace_all = trace_all
        self.save_interval = save_interval
        self._active = False
        self._original_inits = {}
        self._original_methods = {}
        self.additional_metadata = {}
        self._save_task = None
        self._current_query = None
        self.filepath = None
        self.model_names = {}  # Store model names by component instance
        logger.setLevel(log_level)

        if not os.path.exists(output_path):
            os.makedirs(output_path)

        self.reset_trace()


    def __enter__(self):
        """Context manager entry"""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit"""

        self.stop()
        if exc_type:
            logger.error(f"Error in context manager: {exc_val}")
            return False
        return True

    def reset_trace(self):
        """Reset the current trace to initial state with enhanced structure"""
        self.current_trace: Dict[str, Any] = {
            "start_time": None,
            "end_time": None,
            "actions": [],
            "llm_calls": [],
            "chain_starts": [],
            "chain_ends": [],
            "agent_actions": [],
            "chat_model_calls": [],
            "retriever_actions": [],
            "tokens": [],
            "errors": [],
            "query": self._current_query,
            "metadata": {
                "version": "2.0",
                "trace_all": self.trace_all,
                "save_interval": self.save_interval,
            },
        }

    async def _periodic_save(self):
        """Periodically save traces if save_interval is set"""
        while self._active and self.save_interval:
            await asyncio.sleep(self.save_interval)
            await self._async_save_trace()

    async def _async_save_trace(self, force: bool = False):
        """Asynchronously save the current trace to a JSON file"""
        if not self.current_trace["start_time"] and not force:
            return

        try:
            self.current_trace["end_time"] = datetime.now()

            # Use the query from the trace or fallback to a default
            safe_query = self._current_query or "unknown"

            # Sanitize the query for filename
            safe_query = ''.join(c for c in safe_query if c.isalnum() or c.isspace())[:50].strip()

            # Add a timestamp to ensure unique filenames
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"langchain_callback_traces.json"
            filepath = os.path.join(self.output_path, filename)
            self.filepath = filepath

            trace_to_save = self.current_trace.copy()
            trace_to_save["start_time"] = str(trace_to_save["start_time"])
            trace_to_save["end_time"] = str(trace_to_save["end_time"])

            # Save if there are meaningful events or if force is True
            if (
                len(trace_to_save["llm_calls"]) > 0
                or len(trace_to_save["chain_starts"]) > 0
                or len(trace_to_save["chain_ends"]) > 0
                or len(trace_to_save["errors"]) > 0
                or force
            ):
                async with asyncio.Lock():
                    with open(filepath, "w", encoding="utf-8") as f:
                        json.dump(trace_to_save, f, indent=2, default=str)

                logger.info(f"Trace saved to: {filepath}")

            # Reset the current query after saving
            self._current_query = None

            # Reset the trace
            self.reset_trace()

        except Exception as e:
            logger.error(f"Error saving trace: {e}")
            self.on_error(e, context="save_trace")

    def _save_trace(self, force: bool = False):
        """Synchronous version of trace saving"""
        if asyncio.get_event_loop().is_running():
            asyncio.create_task(self._async_save_trace(force))
        else:
            asyncio.run(self._async_save_trace(force))

    def _create_safe_wrapper(self, original_func, component_name, method_name):
        """Create a safely wrapped version of an original function with enhanced error handling"""

        @wraps(original_func)
        def wrapped(*args, **kwargs):
            if not self._active:
                return original_func(*args, **kwargs)

            try:
                # Deep copy kwargs to avoid modifying the original
                kwargs_copy = kwargs.copy() if kwargs is not None else {}

                # Handle different calling conventions
                if 'callbacks' not in kwargs_copy:
                    kwargs_copy['callbacks'] = [self]
                elif self not in kwargs_copy['callbacks']:
                    kwargs_copy['callbacks'].append(self)

                # Store model name if available
                if component_name in ["OpenAI", "ChatOpenAI_LangchainOpenAI", "ChatOpenAI_ChatModels",
                                      "ChatVertexAI", "VertexAI", "ChatGoogleGenerativeAI", "ChatAnthropic",
                                      "ChatLiteLLM", "ChatBedrock", "AzureChatOpenAI", "ChatAnthropicVertex"]:
                    instance = args[0] if args else None
                    model_name = kwargs.get('model_name') or kwargs.get('model') or kwargs.get('model_id')

                    if instance and model_name:
                        self.model_names[id(instance)] = model_name

                # Try different method signatures
                try:
                    # First, try calling with modified kwargs
                    return original_func(*args, **kwargs_copy)
                except TypeError:
                    # If that fails, try without kwargs
                    try:
                        return original_func(*args)
                    except Exception as e:
                        # If all else fails, use original call
                        logger.error(f"Failed to invoke {component_name} with modified callbacks: {e}")
                        return original_func(*args, **kwargs)

            except Exception as e:
                # Log any errors that occur during the function call
                logger.error(f"Error in {component_name} wrapper: {e}")

                # Record the error using the tracer's error handling method
                self.on_error(e, context=f"wrapper_{component_name}")

                # Fallback to calling the original function without modifications
                return original_func(*args, **kwargs)

        @wraps(original_func)
        def wrapped_invoke(*args, **kwargs):
            if not self._active:
                return original_func(*args, **kwargs)

            try:
                # Deep copy kwargs to avoid modifying the original
                kwargs_copy = kwargs.copy() if kwargs is not None else {}

                # Handle different calling conventions
                if 'config' not in kwargs_copy:
                    kwargs_copy['config'] = {'callbacks': [self]}
                elif 'callbacks' not in kwargs_copy['config']:
                    kwargs_copy['config']['callbacks'] = [self]
                elif self not in kwargs_copy['config']['callbacks']:
                    kwargs_copy['config']['callbacks'].append(self)

                # Store model name if available
                if component_name in ["OpenAI", "ChatOpenAI_LangchainOpenAI", "ChatOpenAI_ChatModels",
                                      "ChatVertexAI", "VertexAI", "ChatGoogleGenerativeAI", "ChatAnthropic",
                                      "ChatLiteLLM", "ChatBedrock", "AzureChatOpenAI", "ChatAnthropicVertex"]:
                    instance = args[0] if args else None
                    model_name = kwargs.get('model_name') or kwargs.get('model') or kwargs.get('model_id')

                    if instance and model_name:
                        self.model_names[id(instance)] = model_name

                # Try different method signatures
                try:
                    # First, try calling with modified kwargs
                    return original_func(*args, **kwargs_copy)
                except TypeError:
                    # If that fails, try without kwargs
                    try:
                        return original_func(*args)
                    except Exception as e:
                        # If all else fails, use original call
                        logger.error(f"Failed to invoke {component_name} with modified callbacks: {e}")
                        return original_func(*args, **kwargs)

            except Exception as e:
                # Log any errors that occur during the function call
                logger.error(f"Error in {component_name} wrapper: {e}")

                # Record the error using the tracer's error handling method
                self.on_error(e, context=f"wrapper_{component_name}")

                # Fallback to calling the original function without modifications
                return original_func(*args, **kwargs)

        if method_name == 'invoke':
            return wrapped_invoke
        return wrapped


    def _monkey_patch(self):
        """Enhanced monkey-patching with comprehensive component support"""
        components_to_patch = {}

        try:
            from langchain.llms import OpenAI
            components_to_patch["OpenAI"] = (OpenAI, "__init__")
        except ImportError:
            logger.debug("OpenAI not available for patching")

        try:
            from langchain_aws import ChatBedrock
            components_to_patch["ChatBedrock"] = (ChatBedrock, "__init__")
        except ImportError:
            logger.debug("ChatBedrock not available for patching")

        try:
            from langchain_google_vertexai import ChatVertexAI
            components_to_patch["ChatVertexAI"] = (ChatVertexAI, "__init__")
        except ImportError:
            logger.debug("ChatVertexAI not available for patching")

        try:
            from langchain_google_vertexai import VertexAI
            components_to_patch["VertexAI"] = (VertexAI, "__init__")
        except ImportError:
            logger.debug("VertexAI not available for patching")

        try:
            from langchain_google_vertexai.model_garden import ChatAnthropicVertex
            components_to_patch["ChatAnthropicVertex"] = (ChatAnthropicVertex, "__init__")
        except ImportError:
            logger.debug("ChatAnthropicVertex not available for patching")

        try:
            from langchain_google_genai import ChatGoogleGenerativeAI
            components_to_patch["ChatGoogleGenerativeAI"] = (ChatGoogleGenerativeAI, "__init__")
        except ImportError:
            logger.debug("ChatGoogleGenerativeAI not available for patching")

        try:
            from langchain_anthropic import ChatAnthropic
            components_to_patch["ChatAnthropic"] = (ChatAnthropic, "__init__")
        except ImportError:
            logger.debug("ChatAnthropic not available for patching")

        try:
            from langchain_community.chat_models import ChatLiteLLM
            components_to_patch["ChatLiteLLM"] = (ChatLiteLLM, "__init__")
        except ImportError:
            logger.debug("ChatLiteLLM not available for patching")

        try:
            from langchain_openai import ChatOpenAI as ChatOpenAI_LangchainOpenAI
            components_to_patch["ChatOpenAI_LangchainOpenAI"] = (ChatOpenAI_LangchainOpenAI, "__init__")
        except ImportError:
            logger.debug("ChatOpenAI (from langchain_openai) not available for patching")

        try:
            from langchain_openai import AzureChatOpenAI
            components_to_patch["AzureChatOpenAI"] = (AzureChatOpenAI, "__init__")
        except ImportError:
            logger.debug("AzureChatOpenAI (from langchain_openai) not available for patching")

        try:
            from langchain.chat_models import ChatOpenAI as ChatOpenAI_ChatModels
            components_to_patch["ChatOpenAI_ChatModels"] = (ChatOpenAI_ChatModels, "__init__")
        except ImportError:
            logger.debug("ChatOpenAI (from langchain.chat_models) not available for patching")

        try:
            from langchain.chains import create_retrieval_chain, RetrievalQA
            components_to_patch["RetrievalQA"] = (RetrievalQA, "from_chain_type")
            components_to_patch["create_retrieval_chain"] = (create_retrieval_chain, None)
            components_to_patch['RetrievalQA.invoke'] = (RetrievalQA, 'invoke')
        except ImportError:
            logger.debug("Langchain chains not available for patching")

        for name, (component, method_name) in components_to_patch.items():
            try:
                if method_name == "__init__":
                    original = component.__init__
                    self._original_inits[name] = original
                    component.__init__ = self._create_safe_wrapper(original, name, method_name)
                elif method_name:
                    original = getattr(component, method_name)
                    self._original_methods[name] = original
                    if isinstance(original, classmethod):
                        wrapped = classmethod(
                            self._create_safe_wrapper(original.__func__, name, method_name)
                        )
                    else:
                        wrapped = self._create_safe_wrapper(original, name, method_name)
                    setattr(component, method_name, wrapped)
                else:
                    self._original_methods[name] = component
                    globals()[name] = self._create_safe_wrapper(component, name, method_name)
            except Exception as e:
                logger.error(f"Error patching {name}: {e}")
                self.on_error(e, context=f"patch_{name}")

    def _restore_original_methods(self):
        """Restore all original methods and functions with enhanced error handling"""
        # Dynamically import only what we need based on what was patched
        imported_components = {}

        if self._original_inits or self._original_methods:
            for name in list(self._original_inits.keys()) + list(self._original_methods.keys()):
                try:
                    if name == "OpenAI":
                        from langchain.llms import OpenAI
                        imported_components[name] = OpenAI
                    elif name == "ChatVertexAI":
                        from langchain_google_vertexai import ChatVertexAI
                        imported_components[name] = ChatVertexAI
                    elif name == "VertexAI":
                        from langchain_google_vertexai import VertexAI
                        imported_components[name] = VertexAI
                    elif name == "ChatGoogleGenerativeAI":
                        from langchain_google_genai import ChatGoogleGenerativeAI
                        imported_components[name] = ChatGoogleGenerativeAI
                    elif name == "ChatAnthropic":
                        from langchain_anthropic import ChatAnthropic
                        imported_components[name] = ChatAnthropic
                    elif name == "ChatBedrock":
                        from langchain_aws import ChatBedrock
                        imported_components[name] = ChatBedrock
                    elif name == "AzureChatOpenAI":
                        from langchain_openai import AzureChatOpenAI
                        imported_components[name] = AzureChatOpenAI
                    elif name == "ChatAnthropicVertex":
                        from langchain_google_vertexai.model_garden import ChatAnthropicVertex
                        imported_components[name] = ChatAnthropicVertex
                    elif name == "ChatLiteLLM":
                        from langchain_community.chat_models import ChatLiteLLM
                        imported_components[name] = ChatLiteLLM
                    elif name == "ChatOpenAI_LangchainOpenAI":
                        from langchain_openai import ChatOpenAI as ChatOpenAI_LangchainOpenAI
                        imported_components[name] = ChatOpenAI_LangchainOpenAI
                    elif name == "ChatOpenAI_ChatModels":
                        from langchain.chat_models import ChatOpenAI as ChatOpenAI_ChatModels
                        imported_components[name] = ChatOpenAI_ChatModels
                    elif name in ["RetrievalQA", "create_retrieval_chain", 'RetrievalQA.invoke']:
                        from langchain.chains import create_retrieval_chain, RetrievalQA
                        imported_components["RetrievalQA"] = RetrievalQA
                        imported_components["create_retrieval_chain"] = create_retrieval_chain
                except ImportError:
                    logger.debug(f"{name} not available for restoration")

        for name, original in self._original_inits.items():
            try:
                if name in imported_components:
                    component = imported_components[name]
                    component.__init__ = original
            except Exception as e:
                logger.error(f"Error restoring {name}: {e}")
                self.on_error(e, context=f"restore_{name}")

        # Restore original methods and functions
        for name, original in self._original_methods.items():
            try:
                if "." in name:
                    module_name, method_name = name.rsplit(".", 1)
                    if module_name in imported_components:
                        module = imported_components[module_name]
                        setattr(module, method_name, original)
                else:
                    if name in imported_components:
                        globals()[name] = original
            except Exception as e:
                logger.error(f"Error restoring {name}: {e}")
                self.on_error(e, context=f"restore_{name}")

    def start(self):
        """Start tracing with enhanced error handling and async support"""
        try:
            self.reset_trace()
            self.current_trace["start_time"] = datetime.now()
            self._active = True
            self._monkey_patch()

            if self.save_interval:
                loop = asyncio.get_event_loop()
                self._save_task = loop.create_task(self._periodic_save())

            logger.info("Tracing started")
        except Exception as e:
            logger.error(f"Error starting tracer: {e}")
            self.on_error(e, context="start")
            raise

    def stop(self):
        """Stop tracing with enhanced cleanup"""
        try:
            self._active = False
            if self._save_task:
                self._save_task.cancel()
            self._restore_original_methods()
            # self._save_trace(force=True)

            return self.current_trace.copy(), self.additional_metadata

            logger.info("Tracing stopped")
        except Exception as e:
            logger.error(f"Error stopping tracer: {e}")
            self.on_error(e, context="stop")
            raise
        finally:
            self._original_inits.clear()
            self._original_methods.clear()

    def force_save(self):
        """Force save the current trace"""
        self._save_trace(force=True)

    # Callback methods with enhanced error handling and logging
    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        run_id: UUID,
        **kwargs: Any,
    ) -> None:
        try:
            if not self.current_trace["start_time"]:
                self.current_trace["start_time"] = datetime.now()

            self.current_trace["llm_calls"].append(
                {
                    "timestamp": datetime.now(),
                    "event": "llm_start",
                    "serialized": serialized,
                    "prompts": prompts,
                    "run_id": str(run_id),
                    "additional_kwargs": kwargs,
                }
            )
        except Exception as e:
            self.on_error(e, context="llm_start")

    def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
        try:
            self.current_trace["llm_calls"].append(
                {
                    "timestamp": datetime.now(),
                    "event": "llm_end",
                    "response": response.dict(),
                    "run_id": str(run_id),
                    "additional_kwargs": kwargs,
                }
            )

            # Calculate latency
            end_time = datetime.now()
            latency = (end_time - self.current_trace["start_time"]).total_seconds()

            # Check if values are there in llm_output
            model = ""
            prompt_tokens = 0
            completion_tokens = 0
            total_tokens = 0

            # Try to get model name from llm_output first
            if response and response.llm_output:
                try:
                    model = response.llm_output.get("model_name")
                    if not model:
                        model = response.llm_output.get("model", "")
                except Exception as e:
                    # logger.debug(f"Error getting model name: {e}")
                    model = ""

            # Add model name
            if not model:
                try:
                    model = response.llm_output.get("model_name")
                    if not model:
                        model = response.llm_output.get("model", "")
                except Exception as e:
                    # logger.debug(f"Error getting model name: {e}")
                    model = ""


            # Add token usage
            try:
                token_usage = response.llm_output.get("token_usage", {})
                if token_usage=={}:
                    try:
                        token_usage = response.llm_output.get("usage")
                    except Exception as e:
                        # logger.debug(f"Error getting token usage: {e}")
                        token_usage = {}

                if token_usage !={}:
                    prompt_tokens = token_usage.get("prompt_tokens", 0)
                    if prompt_tokens==0:
                        prompt_tokens = token_usage.get("input_tokens", 0)
                    completion_tokens = token_usage.get("completion_tokens", 0)
                    if completion_tokens==0:
                        completion_tokens = token_usage.get("output_tokens", 0)

                    total_tokens = prompt_tokens + completion_tokens
            except Exception as e:
                # logger.debug(f"Error getting token usage: {e}")
                prompt_tokens = 0
                completion_tokens = 0
                total_tokens = 0

            # Check if values are there in
            if prompt_tokens == 0 and completion_tokens == 0:
                try:
                    usage_data = response.generations[0][0].message.usage_metadata
                    prompt_tokens = usage_data.get("input_tokens", 0)
                    completion_tokens = usage_data.get("output_tokens", 0)
                    total_tokens = prompt_tokens + completion_tokens
                except Exception as e:
                    # logger.debug(f"Error getting usage data: {e}")
                    try:
                        usage_data = response.generations[0][0].generation_info['usage_metadata']
                        prompt_tokens = usage_data.get("prompt_token_count", 0)
                        completion_tokens = usage_data.get("candidates_token_count", 0)
                        total_tokens = prompt_tokens + completion_tokens
                    except Exception as e:
                        # logger.debug(f"Error getting token usage: {e}")
                        prompt_tokens = 0
                        completion_tokens = 0
                        total_tokens = 0

            # If no model name in llm_output, try to get it from stored model names
            try:
                if model == "":
                    model = list(self.model_names.values())[0]
            except Exception as e:
                model=""

            self.additional_metadata = {
                'latency': latency,
                'model_name': model,
                'tokens': {
                    'prompt': prompt_tokens,
                    'completion': completion_tokens,
                    'total': total_tokens
                }
            }

        except Exception as e:
            self.on_error(e, context="llm_end")

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> None:
        try:
            messages_dict = [
                [
                    {
                        "type": msg.type,
                        "content": msg.content,
                        "additional_kwargs": msg.additional_kwargs,
                    }
                    for msg in batch
                ]
                for batch in messages
            ]

            self.current_trace["chat_model_calls"].append(
                {
                    "timestamp": datetime.now(),
                    "event": "chat_model_start",
                    "serialized": serialized,
                    "messages": messages_dict,
                    "run_id": str(run_id),
                    "additional_kwargs": kwargs,
                }
            )
        except Exception as e:
            self.on_error(e, context="chat_model_start")

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> None:
        try:
            context = ""
            query = ""
            if isinstance(inputs, dict):
                if "context" in inputs:
                    if isinstance(inputs["context"], Document):
                        context = inputs["context"].page_content
                    elif isinstance(inputs["context"], list):
                        context = "\n".join(
                            doc.page_content if isinstance(doc, Document) else str(doc)
                            for doc in inputs["context"]
                        )
                    elif isinstance(inputs["context"], str):
                        context = inputs["context"]

                query = inputs.get("question", inputs.get("input", ""))

                # Set the current query
                self._current_query = query

                chain_event = {
                    "timestamp": datetime.now(),
                    "serialized": serialized,
                    "context": context,
                    "query": inputs.get("question", inputs.get("input", "")),
                    "run_id": str(run_id),
                    "additional_kwargs": kwargs,
                }

                self.current_trace["chain_starts"].append(chain_event)
        except Exception as e:
            self.on_error(e, context="chain_start")

    def on_chain_end(
        self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any
    ) -> None:
        try:
            self.current_trace["chain_ends"].append(
                {
                    "timestamp": datetime.now(),
                    "outputs": outputs,
                    "run_id": str(run_id),
                    "additional_kwargs": kwargs,
                }
            )
        except Exception as e:
            self.on_error(e, context="chain_end")

    def on_agent_action(self, action: AgentAction, run_id: UUID, **kwargs: Any) -> None:
        try:
            self.current_trace["agent_actions"].append(
                {
                    "timestamp": datetime.now(),
                    "action": action.dict(),
                    "run_id": str(run_id),
                    "additional_kwargs": kwargs,
                }
            )
        except Exception as e:
            self.on_error(e, context="agent_action")

    def on_agent_finish(self, finish: AgentFinish, run_id: UUID, **kwargs: Any) -> None:
        try:
            self.current_trace["agent_actions"].append(
                {
                    "timestamp": datetime.now(),
                    "event": "agent_finish",
                    "finish": finish.dict(),
                    "run_id": str(run_id),
                    "additional_kwargs": kwargs,
                }
            )
        except Exception as e:
            self.on_error(e, context="agent_finish")

    def on_retriever_start(
        self, serialized: Dict[str, Any], query: str, *, run_id: UUID, **kwargs: Any
    ) -> None:
        try:
            retriever_event = {
                "timestamp": datetime.now(),
                "event": "retriever_start",
                "serialized": serialized,
                "query": query,
                "run_id": str(run_id),
                "additional_kwargs": kwargs,
            }

            self.current_trace["retriever_actions"].append(retriever_event)
        except Exception as e:
            self.on_error(e, context="retriever_start")

    def on_retriever_end(
        self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
    ) -> None:
        try:
            processed_documents = [
                {"page_content": doc.page_content, "metadata": doc.metadata}
                for doc in documents
            ]

            retriever_event = {
                "timestamp": datetime.now(),
                "event": "retriever_end",
                "documents": processed_documents,
                "run_id": str(run_id),
                "additional_kwargs": kwargs,
            }

            self.current_trace["retriever_actions"].append(retriever_event)
        except Exception as e:
            self.on_error(e, context="retriever_end")

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        try:
            self.current_trace["tokens"].append(
                {
                    "timestamp": datetime.now(),
                    "event": "new_token",
                    "token": token,
                    "additional_kwargs": kwargs,
                }
            )
        except Exception as e:
            self.on_error(e, context="llm_new_token")

    def on_error(self, error: Exception, context: str = "", **kwargs: Any) -> None:
        """Enhanced error handling with context"""
        try:
            error_event = {
                "timestamp": datetime.now(),
                "error": str(error),
                "error_type": type(error).__name__,
                "context": context,
                "additional_kwargs": kwargs,
            }
            self.current_trace["errors"].append(error_event)
            logger.error(f"Error in {context}: {error}")
        except Exception as e:
            logger.critical(f"Error in error handler: {e}")

    def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
        self.on_error(error, context="chain", **kwargs)

    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
        self.on_error(error, context="llm", **kwargs)

    def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
        self.on_error(error, context="tool", **kwargs)

    def on_retriever_error(self, error: Exception, **kwargs: Any) -> None:
        self.on_error(error, context="retriever", **kwargs)
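
The handler above is presumably consumed by the package's own tracer integration (ragaai_catalyst/tracers/tracer.py also changes in this release), but it can be exercised directly. A minimal usage sketch follows; it is not part of the package, and the chain/model construction inside the try block is assumed. start() monkey-patches the supported LangChain classes so the handler injects itself as a callback, stop() restores the patched classes and returns the collected trace together with latency/model/token metadata, and force_save() writes langchain_callback_traces.json under output_path.

# Hypothetical usage sketch for the LangchainTracer defined above (not from the package).
from ragaai_catalyst.tracers.langchain_callback import LangchainTracer

tracer = LangchainTracer(output_path="./traces")
tracer.start()                       # patch supported LangChain classes and begin collecting events
try:
    # ... build and invoke a LangChain chat model or retrieval chain here ...
    pass
finally:
    trace, metadata = tracer.stop()  # restore patched classes; returns (trace dict, additional_metadata)
    tracer.force_save()              # write langchain_callback_traces.json into output_path
print(metadata)                      # e.g. {'latency': ..., 'model_name': ..., 'tokens': {...}}

Because __enter__ and __exit__ simply call start() and stop(), the same handler can also be used as a context manager (with LangchainTracer() as cb: ...).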