ragaai-catalyst 2.1.4.1b0__py3-none-any.whl → 2.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. ragaai_catalyst/__init__.py +23 -2
  2. ragaai_catalyst/dataset.py +462 -1
  3. ragaai_catalyst/evaluation.py +76 -7
  4. ragaai_catalyst/ragaai_catalyst.py +52 -10
  5. ragaai_catalyst/redteaming/__init__.py +7 -0
  6. ragaai_catalyst/redteaming/config/detectors.toml +13 -0
  7. ragaai_catalyst/redteaming/data_generator/scenario_generator.py +95 -0
  8. ragaai_catalyst/redteaming/data_generator/test_case_generator.py +120 -0
  9. ragaai_catalyst/redteaming/evaluator.py +125 -0
  10. ragaai_catalyst/redteaming/llm_generator.py +136 -0
  11. ragaai_catalyst/redteaming/llm_generator_old.py +83 -0
  12. ragaai_catalyst/redteaming/red_teaming.py +331 -0
  13. ragaai_catalyst/redteaming/requirements.txt +4 -0
  14. ragaai_catalyst/redteaming/tests/grok.ipynb +97 -0
  15. ragaai_catalyst/redteaming/tests/stereotype.ipynb +2258 -0
  16. ragaai_catalyst/redteaming/upload_result.py +38 -0
  17. ragaai_catalyst/redteaming/utils/issue_description.py +114 -0
  18. ragaai_catalyst/redteaming/utils/rt.png +0 -0
  19. ragaai_catalyst/redteaming_old.py +171 -0
  20. ragaai_catalyst/synthetic_data_generation.py +400 -22
  21. ragaai_catalyst/tracers/__init__.py +17 -1
  22. ragaai_catalyst/tracers/agentic_tracing/data/data_structure.py +4 -2
  23. ragaai_catalyst/tracers/agentic_tracing/tracers/agent_tracer.py +212 -148
  24. ragaai_catalyst/tracers/agentic_tracing/tracers/base.py +657 -247
  25. ragaai_catalyst/tracers/agentic_tracing/tracers/custom_tracer.py +50 -19
  26. ragaai_catalyst/tracers/agentic_tracing/tracers/llm_tracer.py +588 -177
  27. ragaai_catalyst/tracers/agentic_tracing/tracers/main_tracer.py +99 -100
  28. ragaai_catalyst/tracers/agentic_tracing/tracers/network_tracer.py +3 -3
  29. ragaai_catalyst/tracers/agentic_tracing/tracers/tool_tracer.py +230 -29
  30. ragaai_catalyst/tracers/agentic_tracing/upload/trace_uploader.py +358 -0
  31. ragaai_catalyst/tracers/agentic_tracing/upload/upload_agentic_traces.py +75 -20
  32. ragaai_catalyst/tracers/agentic_tracing/upload/upload_code.py +55 -11
  33. ragaai_catalyst/tracers/agentic_tracing/upload/upload_local_metric.py +74 -0
  34. ragaai_catalyst/tracers/agentic_tracing/upload/upload_trace_metric.py +47 -16
  35. ragaai_catalyst/tracers/agentic_tracing/utils/create_dataset_schema.py +4 -2
  36. ragaai_catalyst/tracers/agentic_tracing/utils/file_name_tracker.py +26 -3
  37. ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +182 -17
  38. ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +1233 -497
  39. ragaai_catalyst/tracers/agentic_tracing/utils/span_attributes.py +81 -10
  40. ragaai_catalyst/tracers/agentic_tracing/utils/supported_llm_provider.toml +34 -0
  41. ragaai_catalyst/tracers/agentic_tracing/utils/system_monitor.py +215 -0
  42. ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py +0 -32
  43. ragaai_catalyst/tracers/agentic_tracing/utils/unique_decorator.py +3 -1
  44. ragaai_catalyst/tracers/agentic_tracing/utils/zip_list_of_unique_files.py +73 -47
  45. ragaai_catalyst/tracers/distributed.py +300 -0
  46. ragaai_catalyst/tracers/exporters/__init__.py +3 -1
  47. ragaai_catalyst/tracers/exporters/dynamic_trace_exporter.py +160 -0
  48. ragaai_catalyst/tracers/exporters/ragaai_trace_exporter.py +129 -0
  49. ragaai_catalyst/tracers/langchain_callback.py +809 -0
  50. ragaai_catalyst/tracers/llamaindex_instrumentation.py +424 -0
  51. ragaai_catalyst/tracers/tracer.py +301 -55
  52. ragaai_catalyst/tracers/upload_traces.py +24 -7
  53. ragaai_catalyst/tracers/utils/convert_langchain_callbacks_output.py +61 -0
  54. ragaai_catalyst/tracers/utils/convert_llama_instru_callback.py +69 -0
  55. ragaai_catalyst/tracers/utils/extraction_logic_llama_index.py +74 -0
  56. ragaai_catalyst/tracers/utils/langchain_tracer_extraction_logic.py +82 -0
  57. ragaai_catalyst/tracers/utils/model_prices_and_context_window_backup.json +9365 -0
  58. ragaai_catalyst/tracers/utils/trace_json_converter.py +269 -0
  59. {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/METADATA +367 -45
  60. ragaai_catalyst-2.1.5.dist-info/RECORD +97 -0
  61. {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/WHEEL +1 -1
  62. ragaai_catalyst-2.1.4.1b0.dist-info/RECORD +0 -67
  63. {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/LICENSE +0 -0
  64. {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,809 @@
+from typing import Any, Dict, List, Optional, Union, Sequence
+
+import attr
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import LLMResult, AgentAction, AgentFinish, BaseMessage
+from datetime import datetime
+import json
+import os
+from uuid import UUID
+from functools import wraps
+import asyncio
+from langchain_core.documents import Document
+import logging
+import tempfile
+import sys
+import importlib
+from importlib.util import find_spec
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class LangchainTracer(BaseCallbackHandler):
+    """
+    An enhanced callback handler for LangChain that traces all actions and saves them to a JSON file.
+    Includes improved error handling, async support, and configuration options.
+    """
+
+    def __init__(
+        self,
+        output_path: str = tempfile.gettempdir(),
+        trace_all: bool = True,
+        save_interval: Optional[int] = None,
+        log_level: int = logging.INFO,
+    ):
+        """
+        Initialize the tracer with enhanced configuration options.
+
+        Args:
+            output_path (str): Directory where trace files will be saved
+            trace_all (bool): Whether to trace all components or only specific ones
+            save_interval (Optional[int]): Interval in seconds to auto-save traces
+            log_level (int): Logging level for the tracer
+        """
+        super().__init__()
+        self.output_path = output_path
+        self.trace_all = trace_all
+        self.save_interval = save_interval
+        self._active = False
+        self._original_inits = {}
+        self._original_methods = {}
+        self.additional_metadata = {}
+        self._save_task = None
+        self._current_query = None
+        self.filepath = None
+        self.model_names = {}  # Store model names by component instance
+        logger.setLevel(log_level)
+
+        if not os.path.exists(output_path):
+            os.makedirs(output_path)
+
+        self.reset_trace()
+
+    def __enter__(self):
+        """Context manager entry"""
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager exit"""
+        self.stop()
+        if exc_type:
+            logger.error(f"Error in context manager: {exc_val}")
+            return False
+        return True
+
+    def reset_trace(self):
+        """Reset the current trace to initial state with enhanced structure"""
+        self.current_trace: Dict[str, Any] = {
+            "start_time": None,
+            "end_time": None,
+            "actions": [],
+            "llm_calls": [],
+            "chain_starts": [],
+            "chain_ends": [],
+            "agent_actions": [],
+            "chat_model_calls": [],
+            "retriever_actions": [],
+            "tokens": [],
+            "errors": [],
+            "query": self._current_query,
+            "metadata": {
+                "version": "2.0",
+                "trace_all": self.trace_all,
+                "save_interval": self.save_interval,
+            },
+        }
+
+    async def _periodic_save(self):
+        """Periodically save traces if save_interval is set"""
+        while self._active and self.save_interval:
+            await asyncio.sleep(self.save_interval)
+            await self._async_save_trace()
+
+    async def _async_save_trace(self, force: bool = False):
+        """Asynchronously save the current trace to a JSON file"""
+        if not self.current_trace["start_time"] and not force:
+            return
+
+        try:
+            self.current_trace["end_time"] = datetime.now()
+
+            # The trace is always written to the same file; each save overwrites the previous one
+            filename = "langchain_callback_traces.json"
+            filepath = os.path.join(self.output_path, filename)
+            self.filepath = filepath
+
+            trace_to_save = self.current_trace.copy()
+            trace_to_save["start_time"] = str(trace_to_save["start_time"])
+            trace_to_save["end_time"] = str(trace_to_save["end_time"])
+
+            # Save if there are meaningful events or if force is True
+            if (
+                len(trace_to_save["llm_calls"]) > 0
+                or len(trace_to_save["chain_starts"]) > 0
+                or len(trace_to_save["chain_ends"]) > 0
+                or len(trace_to_save["errors"]) > 0
+                or force
+            ):
+                with open(filepath, "w", encoding="utf-8") as f:
+                    json.dump(trace_to_save, f, indent=2, default=str)
+
+                logger.info(f"Trace saved to: {filepath}")
+
+            # Reset the current query after saving
+            self._current_query = None
+
+            # Reset the trace
+            self.reset_trace()
+
+        except Exception as e:
+            logger.error(f"Error saving trace: {e}")
+            self.on_error(e, context="save_trace")
+
+    def _save_trace(self, force: bool = False):
+        """Save the trace, scheduling an async save when an event loop is already running"""
+        if asyncio.get_event_loop().is_running():
+            asyncio.create_task(self._async_save_trace(force))
+        else:
+            asyncio.run(self._async_save_trace(force))
+
+    def _create_safe_wrapper(self, original_func, component_name, method_name):
+        """Create a safely wrapped version of an original function with enhanced error handling"""
+
+        @wraps(original_func)
+        def wrapped(*args, **kwargs):
+            if not self._active:
+                return original_func(*args, **kwargs)
+
+            try:
+                # Shallow-copy kwargs to avoid modifying the caller's dict
+                kwargs_copy = kwargs.copy() if kwargs is not None else {}
+
+                # Register this tracer as a callback
+                if 'callbacks' not in kwargs_copy:
+                    kwargs_copy['callbacks'] = [self]
+                elif self not in kwargs_copy['callbacks']:
+                    kwargs_copy['callbacks'].append(self)
+
+                # Store model name if available
+                if component_name in ["OpenAI", "ChatOpenAI_LangchainOpenAI", "ChatOpenAI_ChatModels",
+                                      "ChatVertexAI", "VertexAI", "ChatGoogleGenerativeAI", "ChatAnthropic",
+                                      "ChatLiteLLM", "ChatBedrock", "AzureChatOpenAI", "ChatAnthropicVertex"]:
+                    instance = args[0] if args else None
+                    model_name = kwargs.get('model_name') or kwargs.get('model') or kwargs.get('model_id')
+
+                    if instance and model_name:
+                        self.model_names[id(instance)] = model_name
+
+                # Try different method signatures
+                try:
+                    # First, try calling with modified kwargs
+                    return original_func(*args, **kwargs_copy)
+                except TypeError:
+                    # If that fails, try without kwargs
+                    try:
+                        return original_func(*args)
+                    except Exception as e:
+                        # If all else fails, use the original call
+                        logger.error(f"Failed to invoke {component_name} with modified callbacks: {e}")
+                        return original_func(*args, **kwargs)
+
+            except Exception as e:
+                # Log any errors that occur during the function call
+                logger.error(f"Error in {component_name} wrapper: {e}")
+
+                # Record the error using the tracer's error handling method
+                self.on_error(e, context=f"wrapper_{component_name}")
+
+                # Fall back to calling the original function without modifications
+                return original_func(*args, **kwargs)
+
+        @wraps(original_func)
+        def wrapped_invoke(*args, **kwargs):
+            if not self._active:
+                return original_func(*args, **kwargs)
+
+            try:
+                # Shallow-copy kwargs to avoid modifying the caller's dict
+                kwargs_copy = kwargs.copy() if kwargs is not None else {}
+
+                # Register this tracer as a callback via the `config` argument used by invoke()
+                if 'config' not in kwargs_copy:
+                    kwargs_copy['config'] = {'callbacks': [self]}
+                elif 'callbacks' not in kwargs_copy['config']:
+                    kwargs_copy['config']['callbacks'] = [self]
+                elif self not in kwargs_copy['config']['callbacks']:
+                    kwargs_copy['config']['callbacks'].append(self)
+
+                # Store model name if available
+                if component_name in ["OpenAI", "ChatOpenAI_LangchainOpenAI", "ChatOpenAI_ChatModels",
+                                      "ChatVertexAI", "VertexAI", "ChatGoogleGenerativeAI", "ChatAnthropic",
+                                      "ChatLiteLLM", "ChatBedrock", "AzureChatOpenAI", "ChatAnthropicVertex"]:
+                    instance = args[0] if args else None
+                    model_name = kwargs.get('model_name') or kwargs.get('model') or kwargs.get('model_id')
+
+                    if instance and model_name:
+                        self.model_names[id(instance)] = model_name
+
+                # Try different method signatures
+                try:
+                    # First, try calling with modified kwargs
+                    return original_func(*args, **kwargs_copy)
+                except TypeError:
+                    # If that fails, try without kwargs
+                    try:
+                        return original_func(*args)
+                    except Exception as e:
+                        # If all else fails, use the original call
+                        logger.error(f"Failed to invoke {component_name} with modified callbacks: {e}")
+                        return original_func(*args, **kwargs)
+
+            except Exception as e:
+                # Log any errors that occur during the function call
+                logger.error(f"Error in {component_name} wrapper: {e}")
+
+                # Record the error using the tracer's error handling method
+                self.on_error(e, context=f"wrapper_{component_name}")
+
+                # Fall back to calling the original function without modifications
+                return original_func(*args, **kwargs)
+
+        if method_name == 'invoke':
+            return wrapped_invoke
+        return wrapped
+
+    def _monkey_patch(self):
+        """Enhanced monkey-patching with comprehensive component support"""
+        components_to_patch = {}
+
+        try:
+            from langchain.llms import OpenAI
+            components_to_patch["OpenAI"] = (OpenAI, "__init__")
+        except ImportError:
+            logger.debug("OpenAI not available for patching")
+
+        try:
+            from langchain_aws import ChatBedrock
+            components_to_patch["ChatBedrock"] = (ChatBedrock, "__init__")
+        except ImportError:
+            logger.debug("ChatBedrock not available for patching")
+
+        try:
+            from langchain_google_vertexai import ChatVertexAI
+            components_to_patch["ChatVertexAI"] = (ChatVertexAI, "__init__")
+        except ImportError:
+            logger.debug("ChatVertexAI not available for patching")
+
+        try:
+            from langchain_google_vertexai import VertexAI
+            components_to_patch["VertexAI"] = (VertexAI, "__init__")
+        except ImportError:
+            logger.debug("VertexAI not available for patching")
+
+        try:
+            from langchain_google_vertexai.model_garden import ChatAnthropicVertex
+            components_to_patch["ChatAnthropicVertex"] = (ChatAnthropicVertex, "__init__")
+        except ImportError:
+            logger.debug("ChatAnthropicVertex not available for patching")
+
+        try:
+            from langchain_google_genai import ChatGoogleGenerativeAI
+            components_to_patch["ChatGoogleGenerativeAI"] = (ChatGoogleGenerativeAI, "__init__")
+        except ImportError:
+            logger.debug("ChatGoogleGenerativeAI not available for patching")
+
+        try:
+            from langchain_anthropic import ChatAnthropic
+            components_to_patch["ChatAnthropic"] = (ChatAnthropic, "__init__")
+        except ImportError:
+            logger.debug("ChatAnthropic not available for patching")
+
+        try:
+            from langchain_community.chat_models import ChatLiteLLM
+            components_to_patch["ChatLiteLLM"] = (ChatLiteLLM, "__init__")
+        except ImportError:
+            logger.debug("ChatLiteLLM not available for patching")
+
+        try:
+            from langchain_openai import ChatOpenAI as ChatOpenAI_LangchainOpenAI
+            components_to_patch["ChatOpenAI_LangchainOpenAI"] = (ChatOpenAI_LangchainOpenAI, "__init__")
+        except ImportError:
+            logger.debug("ChatOpenAI (from langchain_openai) not available for patching")
+
+        try:
+            from langchain_openai import AzureChatOpenAI
+            components_to_patch["AzureChatOpenAI"] = (AzureChatOpenAI, "__init__")
+        except ImportError:
+            logger.debug("AzureChatOpenAI (from langchain_openai) not available for patching")
+
+        try:
+            from langchain.chat_models import ChatOpenAI as ChatOpenAI_ChatModels
+            components_to_patch["ChatOpenAI_ChatModels"] = (ChatOpenAI_ChatModels, "__init__")
+        except ImportError:
+            logger.debug("ChatOpenAI (from langchain.chat_models) not available for patching")
+
+        try:
+            from langchain.chains import create_retrieval_chain, RetrievalQA
+            components_to_patch["RetrievalQA"] = (RetrievalQA, "from_chain_type")
+            components_to_patch["create_retrieval_chain"] = (create_retrieval_chain, None)
+            components_to_patch["RetrievalQA.invoke"] = (RetrievalQA, "invoke")
+        except ImportError:
+            logger.debug("Langchain chains not available for patching")
+
+        for name, (component, method_name) in components_to_patch.items():
+            try:
+                if method_name == "__init__":
+                    original = component.__init__
+                    self._original_inits[name] = original
+                    component.__init__ = self._create_safe_wrapper(original, name, method_name)
+                elif method_name:
+                    original = getattr(component, method_name)
+                    self._original_methods[name] = original
+                    if isinstance(original, classmethod):
+                        wrapped = classmethod(
+                            self._create_safe_wrapper(original.__func__, name, method_name)
+                        )
+                    else:
+                        wrapped = self._create_safe_wrapper(original, name, method_name)
+                    setattr(component, method_name, wrapped)
+                else:
+                    self._original_methods[name] = component
+                    globals()[name] = self._create_safe_wrapper(component, name, method_name)
+            except Exception as e:
+                logger.error(f"Error patching {name}: {e}")
+                self.on_error(e, context=f"patch_{name}")
+
+    def _restore_original_methods(self):
+        """Restore all original methods and functions with enhanced error handling"""
+        # Dynamically import only what we need based on what was patched
+        imported_components = {}
+
+        if self._original_inits or self._original_methods:
+            for name in list(self._original_inits.keys()) + list(self._original_methods.keys()):
+                try:
+                    if name == "OpenAI":
+                        from langchain.llms import OpenAI
+                        imported_components[name] = OpenAI
+                    elif name == "ChatVertexAI":
+                        from langchain_google_vertexai import ChatVertexAI
+                        imported_components[name] = ChatVertexAI
+                    elif name == "VertexAI":
+                        from langchain_google_vertexai import VertexAI
+                        imported_components[name] = VertexAI
+                    elif name == "ChatGoogleGenerativeAI":
+                        from langchain_google_genai import ChatGoogleGenerativeAI
+                        imported_components[name] = ChatGoogleGenerativeAI
+                    elif name == "ChatAnthropic":
+                        from langchain_anthropic import ChatAnthropic
+                        imported_components[name] = ChatAnthropic
+                    elif name == "ChatBedrock":
+                        from langchain_aws import ChatBedrock
+                        imported_components[name] = ChatBedrock
+                    elif name == "AzureChatOpenAI":
+                        from langchain_openai import AzureChatOpenAI
+                        imported_components[name] = AzureChatOpenAI
+                    elif name == "ChatAnthropicVertex":
+                        from langchain_google_vertexai.model_garden import ChatAnthropicVertex
+                        imported_components[name] = ChatAnthropicVertex
+                    elif name == "ChatLiteLLM":
+                        from langchain_community.chat_models import ChatLiteLLM
+                        imported_components[name] = ChatLiteLLM
+                    elif name == "ChatOpenAI_LangchainOpenAI":
+                        from langchain_openai import ChatOpenAI as ChatOpenAI_LangchainOpenAI
+                        imported_components[name] = ChatOpenAI_LangchainOpenAI
+                    elif name == "ChatOpenAI_ChatModels":
+                        from langchain.chat_models import ChatOpenAI as ChatOpenAI_ChatModels
+                        imported_components[name] = ChatOpenAI_ChatModels
+                    elif name in ["RetrievalQA", "create_retrieval_chain", "RetrievalQA.invoke"]:
+                        from langchain.chains import create_retrieval_chain, RetrievalQA
+                        imported_components["RetrievalQA"] = RetrievalQA
+                        imported_components["create_retrieval_chain"] = create_retrieval_chain
+                except ImportError:
+                    logger.debug(f"{name} not available for restoration")
+
+        for name, original in self._original_inits.items():
+            try:
+                if name in imported_components:
+                    component = imported_components[name]
+                    component.__init__ = original
+            except Exception as e:
+                logger.error(f"Error restoring {name}: {e}")
+                self.on_error(e, context=f"restore_{name}")
+
+        # Restore original methods and functions
+        for name, original in self._original_methods.items():
+            try:
+                if "." in name:
+                    module_name, method_name = name.rsplit(".", 1)
+                    if module_name in imported_components:
+                        module = imported_components[module_name]
+                        setattr(module, method_name, original)
+                else:
+                    if name in imported_components:
+                        globals()[name] = original
+            except Exception as e:
+                logger.error(f"Error restoring {name}: {e}")
+                self.on_error(e, context=f"restore_{name}")
+
+    def start(self):
+        """Start tracing with enhanced error handling and async support"""
+        try:
+            self.reset_trace()
+            self.current_trace["start_time"] = datetime.now()
+            self._active = True
+            self._monkey_patch()
+
+            if self.save_interval:
+                loop = asyncio.get_event_loop()
+                self._save_task = loop.create_task(self._periodic_save())
+
+            logger.info("Tracing started")
+        except Exception as e:
+            logger.error(f"Error starting tracer: {e}")
+            self.on_error(e, context="start")
+            raise
+
+    def stop(self):
+        """Stop tracing with enhanced cleanup"""
+        try:
+            self._active = False
+            if self._save_task:
+                self._save_task.cancel()
+            self._restore_original_methods()
+            # self._save_trace(force=True)
+
+            logger.info("Tracing stopped")
+            return self.current_trace.copy(), self.additional_metadata
+        except Exception as e:
+            logger.error(f"Error stopping tracer: {e}")
+            self.on_error(e, context="stop")
+            raise
+        finally:
+            self._original_inits.clear()
+            self._original_methods.clear()
+
+    def force_save(self):
+        """Force save the current trace"""
+        self._save_trace(force=True)
+
+    # Callback methods with enhanced error handling and logging
+    def on_llm_start(
+        self,
+        serialized: Dict[str, Any],
+        prompts: List[str],
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> None:
+        try:
+            if not self.current_trace["start_time"]:
+                self.current_trace["start_time"] = datetime.now()
+
+            self.current_trace["llm_calls"].append(
+                {
+                    "timestamp": datetime.now(),
+                    "event": "llm_start",
+                    "serialized": serialized,
+                    "prompts": prompts,
+                    "run_id": str(run_id),
+                    "additional_kwargs": kwargs,
+                }
+            )
+        except Exception as e:
+            self.on_error(e, context="llm_start")
+
+    def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
+        try:
+            self.current_trace["llm_calls"].append(
+                {
+                    "timestamp": datetime.now(),
+                    "event": "llm_end",
+                    "response": response.dict(),
+                    "run_id": str(run_id),
+                    "additional_kwargs": kwargs,
+                }
+            )
+
+            # Calculate latency from the trace start to the end of this LLM call
+            end_time = datetime.now()
+            latency = (end_time - self.current_trace["start_time"]).total_seconds()
+
+            # Default values in case llm_output carries no usable metadata
+            model = ""
+            prompt_tokens = 0
+            completion_tokens = 0
+            total_tokens = 0
+
+            # Try to get the model name from llm_output first
+            if response and response.llm_output:
+                try:
+                    model = response.llm_output.get("model_name")
+                    if not model:
+                        model = response.llm_output.get("model", "")
+                except Exception as e:
+                    logger.debug(f"Error getting model name: {e}")
+                    model = ""
+
+            # Get token usage from llm_output
+            try:
+                token_usage = response.llm_output.get("token_usage", {})
+                if token_usage == {}:
+                    token_usage = response.llm_output.get("usage") or {}
+
+                if token_usage != {}:
+                    prompt_tokens = token_usage.get("prompt_tokens", 0)
+                    if prompt_tokens == 0:
+                        prompt_tokens = token_usage.get("input_tokens", 0)
+                    completion_tokens = token_usage.get("completion_tokens", 0)
+                    if completion_tokens == 0:
+                        completion_tokens = token_usage.get("output_tokens", 0)
+
+                    total_tokens = prompt_tokens + completion_tokens
+            except Exception as e:
+                logger.debug(f"Error getting token usage: {e}")
+                prompt_tokens = 0
+                completion_tokens = 0
+                total_tokens = 0
+
+            # Fall back to usage metadata attached to the generations
+            if prompt_tokens == 0 and completion_tokens == 0:
+                try:
+                    usage_data = response.generations[0][0].message.usage_metadata
+                    prompt_tokens = usage_data.get("input_tokens", 0)
+                    completion_tokens = usage_data.get("output_tokens", 0)
+                    total_tokens = prompt_tokens + completion_tokens
+                except Exception as e:
+                    logger.debug(f"Error getting usage data: {e}")
+                    try:
+                        usage_data = response.generations[0][0].generation_info['usage_metadata']
+                        prompt_tokens = usage_data.get("prompt_token_count", 0)
+                        completion_tokens = usage_data.get("candidates_token_count", 0)
+                        total_tokens = prompt_tokens + completion_tokens
+                    except Exception as e:
+                        logger.debug(f"Error getting token usage: {e}")
+                        prompt_tokens = 0
+                        completion_tokens = 0
+                        total_tokens = 0
+
+            # If no model name in llm_output, try to get it from stored model names
+            try:
+                if model == "":
+                    model = list(self.model_names.values())[0]
+            except Exception:
+                model = ""
+
+            self.additional_metadata = {
+                'latency': latency,
+                'model_name': model,
+                'tokens': {
+                    'prompt': prompt_tokens,
+                    'completion': completion_tokens,
+                    'total': total_tokens
+                }
+            }
+
+        except Exception as e:
+            self.on_error(e, context="llm_end")
+
+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> None:
+        try:
+            messages_dict = [
+                [
+                    {
+                        "type": msg.type,
+                        "content": msg.content,
+                        "additional_kwargs": msg.additional_kwargs,
+                    }
+                    for msg in batch
+                ]
+                for batch in messages
+            ]
+
+            self.current_trace["chat_model_calls"].append(
+                {
+                    "timestamp": datetime.now(),
+                    "event": "chat_model_start",
+                    "serialized": serialized,
+                    "messages": messages_dict,
+                    "run_id": str(run_id),
+                    "additional_kwargs": kwargs,
+                }
+            )
+        except Exception as e:
+            self.on_error(e, context="chat_model_start")
+
+    def on_chain_start(
+        self,
+        serialized: Dict[str, Any],
+        inputs: Dict[str, Any],
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> None:
+        try:
+            context = ""
+            query = ""
+            if isinstance(inputs, dict):
+                if "context" in inputs:
+                    if isinstance(inputs["context"], Document):
+                        context = inputs["context"].page_content
+                    elif isinstance(inputs["context"], list):
+                        context = "\n".join(
+                            doc.page_content if isinstance(doc, Document) else str(doc)
+                            for doc in inputs["context"]
+                        )
+                    elif isinstance(inputs["context"], str):
+                        context = inputs["context"]
+
+                query = inputs.get("question", inputs.get("input", ""))
+
+            # Set the current query
+            self._current_query = query
+
+            chain_event = {
+                "timestamp": datetime.now(),
+                "serialized": serialized,
+                "context": context,
+                "query": query,
+                "run_id": str(run_id),
+                "additional_kwargs": kwargs,
+            }
+
+            self.current_trace["chain_starts"].append(chain_event)
+        except Exception as e:
+            self.on_error(e, context="chain_start")
+
+    def on_chain_end(
+        self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any
+    ) -> None:
+        try:
+            self.current_trace["chain_ends"].append(
+                {
+                    "timestamp": datetime.now(),
+                    "outputs": outputs,
+                    "run_id": str(run_id),
+                    "additional_kwargs": kwargs,
+                }
+            )
+        except Exception as e:
+            self.on_error(e, context="chain_end")
+
+    def on_agent_action(self, action: AgentAction, run_id: UUID, **kwargs: Any) -> None:
+        try:
+            self.current_trace["agent_actions"].append(
+                {
+                    "timestamp": datetime.now(),
+                    "action": action.dict(),
+                    "run_id": str(run_id),
+                    "additional_kwargs": kwargs,
+                }
+            )
+        except Exception as e:
+            self.on_error(e, context="agent_action")
+
+    def on_agent_finish(self, finish: AgentFinish, run_id: UUID, **kwargs: Any) -> None:
+        try:
+            self.current_trace["agent_actions"].append(
+                {
+                    "timestamp": datetime.now(),
+                    "event": "agent_finish",
+                    "finish": finish.dict(),
+                    "run_id": str(run_id),
+                    "additional_kwargs": kwargs,
+                }
+            )
+        except Exception as e:
+            self.on_error(e, context="agent_finish")
+
+    def on_retriever_start(
+        self, serialized: Dict[str, Any], query: str, *, run_id: UUID, **kwargs: Any
+    ) -> None:
+        try:
+            retriever_event = {
+                "timestamp": datetime.now(),
+                "event": "retriever_start",
+                "serialized": serialized,
+                "query": query,
+                "run_id": str(run_id),
+                "additional_kwargs": kwargs,
+            }
+
+            self.current_trace["retriever_actions"].append(retriever_event)
+        except Exception as e:
+            self.on_error(e, context="retriever_start")
+
+    def on_retriever_end(
+        self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
+    ) -> None:
+        try:
+            processed_documents = [
+                {"page_content": doc.page_content, "metadata": doc.metadata}
+                for doc in documents
+            ]
+
+            retriever_event = {
+                "timestamp": datetime.now(),
+                "event": "retriever_end",
+                "documents": processed_documents,
+                "run_id": str(run_id),
+                "additional_kwargs": kwargs,
+            }
+
+            self.current_trace["retriever_actions"].append(retriever_event)
+        except Exception as e:
+            self.on_error(e, context="retriever_end")
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        try:
+            self.current_trace["tokens"].append(
+                {
+                    "timestamp": datetime.now(),
+                    "event": "new_token",
+                    "token": token,
+                    "additional_kwargs": kwargs,
+                }
+            )
+        except Exception as e:
+            self.on_error(e, context="llm_new_token")
+
+    def on_error(self, error: Exception, context: str = "", **kwargs: Any) -> None:
+        """Enhanced error handling with context"""
+        try:
+            error_event = {
+                "timestamp": datetime.now(),
+                "error": str(error),
+                "error_type": type(error).__name__,
+                "context": context,
+                "additional_kwargs": kwargs,
+            }
+            self.current_trace["errors"].append(error_event)
+            logger.error(f"Error in {context}: {error}")
+        except Exception as e:
+            logger.critical(f"Error in error handler: {e}")
+
+    def on_chain_error(self, error: Exception, **kwargs: Any) -> None:
+        self.on_error(error, context="chain", **kwargs)
+
+    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
+        self.on_error(error, context="llm", **kwargs)
+
+    def on_tool_error(self, error: Exception, **kwargs: Any) -> None:
+        self.on_error(error, context="tool", **kwargs)
+
+    def on_retriever_error(self, error: Exception, **kwargs: Any) -> None:
+        self.on_error(error, context="retriever", **kwargs)
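
For orientation, a minimal usage sketch of the new callback handler, based only on the API visible in the diff above (LangchainTracer is a context manager: start() monkey-patches supported LangChain components so the tracer is injected as a callback, stop() restores the originals, and force_save() writes the trace file). The model choice and prompt below are illustrative assumptions, not part of the package:

from ragaai_catalyst.tracers.langchain_callback import LangchainTracer
from langchain_openai import ChatOpenAI

# Components constructed inside the with-block are patched to include the
# tracer in their callbacks, so on_llm_start/on_llm_end events are recorded.
with LangchainTracer(output_path="./traces") as tracer:
    llm = ChatOpenAI(model="gpt-4o-mini")  # assumed model name, for illustration
    llm.invoke("What is RagaAI Catalyst?")
    tracer.force_save()  # writes langchain_callback_traces.json under ./traces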