ragaai-catalyst 2.1b0__py3-none-any.whl → 2.1b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. ragaai_catalyst/__init__.py +1 -0
  2. ragaai_catalyst/dataset.py +1 -4
  3. ragaai_catalyst/evaluation.py +4 -5
  4. ragaai_catalyst/guard_executor.py +97 -0
  5. ragaai_catalyst/guardrails_manager.py +41 -15
  6. ragaai_catalyst/internal_api_completion.py +1 -1
  7. ragaai_catalyst/prompt_manager.py +7 -2
  8. ragaai_catalyst/ragaai_catalyst.py +1 -1
  9. ragaai_catalyst/synthetic_data_generation.py +7 -0
  10. ragaai_catalyst/tracers/__init__.py +1 -1
  11. ragaai_catalyst/tracers/agentic_tracing/__init__.py +3 -0
  12. ragaai_catalyst/tracers/agentic_tracing/agent_tracer.py +422 -0
  13. ragaai_catalyst/tracers/agentic_tracing/agentic_tracing.py +198 -0
  14. ragaai_catalyst/tracers/agentic_tracing/base.py +376 -0
  15. ragaai_catalyst/tracers/agentic_tracing/data_structure.py +248 -0
  16. ragaai_catalyst/tracers/agentic_tracing/examples/FinancialAnalysisSystem.ipynb +536 -0
  17. ragaai_catalyst/tracers/agentic_tracing/examples/GameActivityEventPlanner.ipynb +134 -0
  18. ragaai_catalyst/tracers/agentic_tracing/examples/TravelPlanner.ipynb +563 -0
  19. ragaai_catalyst/tracers/agentic_tracing/file_name_tracker.py +46 -0
  20. ragaai_catalyst/tracers/agentic_tracing/llm_tracer.py +808 -0
  21. ragaai_catalyst/tracers/agentic_tracing/network_tracer.py +286 -0
  22. ragaai_catalyst/tracers/agentic_tracing/sample.py +197 -0
  23. ragaai_catalyst/tracers/agentic_tracing/tool_tracer.py +247 -0
  24. ragaai_catalyst/tracers/agentic_tracing/unique_decorator.py +165 -0
  25. ragaai_catalyst/tracers/agentic_tracing/unique_decorator_test.py +172 -0
  26. ragaai_catalyst/tracers/agentic_tracing/upload_agentic_traces.py +187 -0
  27. ragaai_catalyst/tracers/agentic_tracing/upload_code.py +115 -0
  28. ragaai_catalyst/tracers/agentic_tracing/user_interaction_tracer.py +43 -0
  29. ragaai_catalyst/tracers/agentic_tracing/utils/__init__.py +3 -0
  30. ragaai_catalyst/tracers/agentic_tracing/utils/api_utils.py +18 -0
  31. ragaai_catalyst/tracers/agentic_tracing/utils/data_classes.py +61 -0
  32. ragaai_catalyst/tracers/agentic_tracing/utils/generic.py +32 -0
  33. ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +177 -0
  34. ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +7823 -0
  35. ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py +74 -0
  36. ragaai_catalyst/tracers/agentic_tracing/zip_list_of_unique_files.py +342 -0
  37. ragaai_catalyst/tracers/exporters/raga_exporter.py +1 -7
  38. ragaai_catalyst/tracers/tracer.py +30 -4
  39. ragaai_catalyst/tracers/upload_traces.py +127 -0
  40. ragaai_catalyst-2.1b1.dist-info/METADATA +43 -0
  41. ragaai_catalyst-2.1b1.dist-info/RECORD +56 -0
  42. {ragaai_catalyst-2.1b0.dist-info → ragaai_catalyst-2.1b1.dist-info}/WHEEL +1 -1
  43. ragaai_catalyst-2.1b0.dist-info/METADATA +0 -295
  44. ragaai_catalyst-2.1b0.dist-info/RECORD +0 -28
  45. {ragaai_catalyst-2.1b0.dist-info → ragaai_catalyst-2.1b1.dist-info}/top_level.txt +0 -0
ragaai_catalyst/tracers/agentic_tracing/llm_tracer.py
@@ -0,0 +1,808 @@
+ from typing import Optional, Any, Dict, List
+ import asyncio
+ import psutil
+ import wrapt
+ import functools
+ from datetime import datetime
+ import uuid
+ import contextvars
+ import traceback
+
+ from .unique_decorator import generate_unique_hash_simple
+ from .utils.trace_utils import load_model_costs
+ from .utils.llm_utils import extract_llm_output
+ from .file_name_tracker import TrackName
+
+
+ class LLMTracerMixin:
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.file_tracker = TrackName()
+         self.patches = []
+         try:
+             self.model_costs = load_model_costs()
+         except Exception as e:
+             self.model_costs = {
+                 # TODO: Default cost handling needs to be improved
+                 "default": {
+                     "input_cost_per_token": 0.0,
+                     "output_cost_per_token": 0.0
+                 }
+             }
+         self.current_llm_call_name = contextvars.ContextVar("llm_call_name", default=None)
+         self.component_network_calls = {}
+         self.component_user_interaction = {}
+         self.current_component_id = None
+         self.total_tokens = 0
+         self.total_cost = 0.0
+         self.llm_data = {}
+
+     def instrument_llm_calls(self):
+         # Handle modules that are already imported
+         import sys
+
+         if "vertexai" in sys.modules:
+             self.patch_vertex_ai_methods(sys.modules["vertexai"])
+         if "vertexai.generative_models" in sys.modules:
+             self.patch_vertex_ai_methods(sys.modules["vertexai.generative_models"])
+
+         if "openai" in sys.modules:
+             self.patch_openai_methods(sys.modules["openai"])
+         if "litellm" in sys.modules:
+             self.patch_litellm_methods(sys.modules["litellm"])
+         if "anthropic" in sys.modules:
+             self.patch_anthropic_methods(sys.modules["anthropic"])
+         if "google.generativeai" in sys.modules:
+             self.patch_google_genai_methods(sys.modules["google.generativeai"])
+         if "langchain_google_vertexai" in sys.modules:
+             self.patch_langchain_google_methods(sys.modules["langchain_google_vertexai"])
+         if "langchain_google_genai" in sys.modules:
+             self.patch_langchain_google_methods(sys.modules["langchain_google_genai"])
+
+         # Register hooks for future imports
+         wrapt.register_post_import_hook(self.patch_vertex_ai_methods, "vertexai")
+         wrapt.register_post_import_hook(self.patch_vertex_ai_methods, "vertexai.generative_models")
+         wrapt.register_post_import_hook(self.patch_openai_methods, "openai")
+         wrapt.register_post_import_hook(self.patch_litellm_methods, "litellm")
+         wrapt.register_post_import_hook(self.patch_anthropic_methods, "anthropic")
+         wrapt.register_post_import_hook(self.patch_google_genai_methods, "google.generativeai")
+
+         # Add hooks for LangChain integrations
+         wrapt.register_post_import_hook(self.patch_langchain_google_methods, "langchain_google_vertexai")
+         wrapt.register_post_import_hook(self.patch_langchain_google_methods, "langchain_google_genai")
+
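For context, `wrapt.register_post_import_hook(hook, name)` runs `hook(module)` as soon as the named module is imported, which is what lets the method above defer patching for SDKs that the user has not loaded yet. A minimal standalone sketch of the mechanism (the `json` target here is arbitrary):

    import wrapt

    def on_import(module):
        # Invoked when the named module first lands in sys.modules
        # (wrapt also fires the hook right away for modules already imported).
        print(f"patching {module.__name__}")

    wrapt.register_post_import_hook(on_import, "json")
    import json  # triggers on_import with the json module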
+     def patch_openai_methods(self, module):
+         try:
+             if hasattr(module, "OpenAI"):
+                 client_class = getattr(module, "OpenAI")
+                 self.wrap_openai_client_methods(client_class)
+             if hasattr(module, "AsyncOpenAI"):
+                 async_client_class = getattr(module, "AsyncOpenAI")
+                 self.wrap_openai_client_methods(async_client_class)
+         except Exception as e:
+             # Log the error but continue execution
+             print(f"Warning: Failed to patch OpenAI methods: {str(e)}")
+
+     def patch_anthropic_methods(self, module):
+         if hasattr(module, "Anthropic"):
+             client_class = getattr(module, "Anthropic")
+             self.wrap_anthropic_client_methods(client_class)
+
+     def patch_google_genai_methods(self, module):
+         # Patch direct Google GenerativeAI usage
+         if hasattr(module, "GenerativeModel"):
+             model_class = getattr(module, "GenerativeModel")
+             self.wrap_genai_model_methods(model_class)
+
+         # Patch LangChain integration
+         if hasattr(module, "ChatGoogleGenerativeAI"):
+             chat_class = getattr(module, "ChatGoogleGenerativeAI")
+             # Wrap invoke method to capture messages
+             original_invoke = chat_class.invoke
+
+             def patched_invoke(self, messages, *args, **kwargs):
+                 # Store messages on the instance for later use
+                 self._last_messages = messages
+                 return original_invoke(self, messages, *args, **kwargs)
+
+             chat_class.invoke = patched_invoke
+
+             # LangChain v0.2+ uses invoke/ainvoke
+             self.wrap_method(chat_class, "_generate")
+             if hasattr(chat_class, "_agenerate"):
+                 self.wrap_method(chat_class, "_agenerate")
+             # Fallback for completion methods
+             if hasattr(chat_class, "complete"):
+                 self.wrap_method(chat_class, "complete")
+             if hasattr(chat_class, "acomplete"):
+                 self.wrap_method(chat_class, "acomplete")
+
+     def patch_vertex_ai_methods(self, module):
+         # Patch the GenerativeModel class
+         if hasattr(module, "generative_models"):
+             gen_models = getattr(module, "generative_models")
+             if hasattr(gen_models, "GenerativeModel"):
+                 model_class = getattr(gen_models, "GenerativeModel")
+                 self.wrap_vertex_model_methods(model_class)
+
+         # Also patch the class directly if available
+         if hasattr(module, "GenerativeModel"):
+             model_class = getattr(module, "GenerativeModel")
+             self.wrap_vertex_model_methods(model_class)
+
+     def wrap_vertex_model_methods(self, model_class):
+         # Patch both sync and async methods
+         self.wrap_method(model_class, "generate_content")
+         if hasattr(model_class, "generate_content_async"):
+             self.wrap_method(model_class, "generate_content_async")
+
+     def patch_litellm_methods(self, module):
+         self.wrap_method(module, "completion")
+         self.wrap_method(module, "acompletion")
+
+     def patch_langchain_google_methods(self, module):
+         """Patch LangChain's Google integration methods"""
+         if hasattr(module, "ChatVertexAI"):
+             chat_class = getattr(module, "ChatVertexAI")
+             # LangChain v0.2+ uses invoke/ainvoke
+             self.wrap_method(chat_class, "_generate")
+             if hasattr(chat_class, "_agenerate"):
+                 self.wrap_method(chat_class, "_agenerate")
+             # Fallback for completion methods
+             if hasattr(chat_class, "complete"):
+                 self.wrap_method(chat_class, "complete")
+             if hasattr(chat_class, "acomplete"):
+                 self.wrap_method(chat_class, "acomplete")
+
+         if hasattr(module, "ChatGoogleGenerativeAI"):
+             chat_class = getattr(module, "ChatGoogleGenerativeAI")
+             # LangChain v0.2+ uses invoke/ainvoke
+             self.wrap_method(chat_class, "_generate")
+             if hasattr(chat_class, "_agenerate"):
+                 self.wrap_method(chat_class, "_agenerate")
+             # Fallback for completion methods
+             if hasattr(chat_class, "complete"):
+                 self.wrap_method(chat_class, "complete")
+             if hasattr(chat_class, "acomplete"):
+                 self.wrap_method(chat_class, "acomplete")
+
+     def wrap_openai_client_methods(self, client_class):
+         original_init = client_class.__init__
+
+         @functools.wraps(original_init)
+         def patched_init(client_self, *args, **kwargs):
+             original_init(client_self, *args, **kwargs)
+             self.wrap_method(client_self.chat.completions, "create")
+             if hasattr(client_self.chat.completions, "acreate"):
+                 self.wrap_method(client_self.chat.completions, "acreate")
+
+         setattr(client_class, "__init__", patched_init)
+
+     def wrap_anthropic_client_methods(self, client_class):
+         original_init = client_class.__init__
+
+         @functools.wraps(original_init)
+         def patched_init(client_self, *args, **kwargs):
+             original_init(client_self, *args, **kwargs)
+             self.wrap_method(client_self.messages, "create")
+             if hasattr(client_self.messages, "acreate"):
+                 self.wrap_method(client_self.messages, "acreate")
+
+         setattr(client_class, "__init__", patched_init)
+
+     def wrap_genai_model_methods(self, model_class):
+         original_init = model_class.__init__
+
+         @functools.wraps(original_init)
+         def patched_init(model_self, *args, **kwargs):
+             original_init(model_self, *args, **kwargs)
+             self.wrap_method(model_self, "generate_content")
+             if hasattr(model_self, "generate_content_async"):
+                 self.wrap_method(model_self, "generate_content_async")
+
+         setattr(model_class, "__init__", patched_init)
+
+     def wrap_method(self, obj, method_name):
+         """
+         Wrap a method with tracing functionality.
+         Works for both class methods and instance methods.
+         """
+         # If obj is a class, we need to patch both the class and any existing instances
+         if isinstance(obj, type):
+             # Store the original class method
+             original_method = getattr(obj, method_name)
+
+             @wrapt.decorator
+             def wrapper(wrapped, instance, args, kwargs):
+                 if asyncio.iscoroutinefunction(wrapped):
+                     return self.trace_llm_call(wrapped, *args, **kwargs)
+                 return self.trace_llm_call_sync(wrapped, *args, **kwargs)
+
+             # Wrap the class method
+             wrapped_method = wrapper(original_method)
+             setattr(obj, method_name, wrapped_method)
+             self.patches.append((obj, method_name, original_method))
+
+         else:
+             # For instance methods
+             original_method = getattr(obj, method_name)
+
+             @wrapt.decorator
+             def wrapper(wrapped, instance, args, kwargs):
+                 if asyncio.iscoroutinefunction(wrapped):
+                     return self.trace_llm_call(wrapped, *args, **kwargs)
+                 return self.trace_llm_call_sync(wrapped, *args, **kwargs)
+
+             wrapped_method = wrapper(original_method)
+             setattr(obj, method_name, wrapped_method)
+             self.patches.append((obj, method_name, original_method))
+
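For reference, `@wrapt.decorator` gives `wrapper` the `(wrapped, instance, args, kwargs)` calling convention used above, and calling `wrapper(original_method)` returns a transparent proxy that preserves the original's name, docstring, and signature. A minimal sketch:

    import wrapt

    @wrapt.decorator
    def log_calls(wrapped, instance, args, kwargs):
        # args and kwargs arrive pre-packed; forward them explicitly.
        print(f"calling {wrapped.__name__}")
        return wrapped(*args, **kwargs)

    @log_calls
    def add(a, b):
        return a + b

    add(1, 2)  # prints "calling add" and returns 3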
+     def _extract_model_name(self, args, kwargs, result):
+         """Extract model name from kwargs or result"""
+         # First try direct model parameter
+         model = kwargs.get("model", "")
+
+         if not model:
+             # Try to get from instance
+             instance = kwargs.get("self", None)
+             if instance:
+                 # Try model_name first (Google format)
+                 if hasattr(instance, "model_name"):
+                     model = instance.model_name
+                 # Try model attribute
+                 elif hasattr(instance, "model"):
+                     model = instance.model
+
+         # TODO: This isn't scalable; hard-coded model-name normalisation shouldn't be necessary
+         # Normalize Google model names
+         if model and isinstance(model, str):
+             model = model.lower()
+             if "gemini-1.5-flash" in model:
+                 return "gemini-1.5-flash"
+             if "gemini-1.5-pro" in model:
+                 return "gemini-1.5-pro"
+             if "gemini-pro" in model:
+                 return "gemini-pro"
+
+         if 'to_dict' in dir(result):
+             result = result.to_dict()
+             if 'model_version' in result:
+                 model = result['model_version']
+
+         return model or "default"
+
+     def _extract_parameters(self, kwargs):
+         """Extract all non-null parameters from kwargs"""
+         parameters = {k: v for k, v in kwargs.items() if v is not None}
+
+         # Remove contents key in parameters (Google LLM Response)
+         if 'contents' in parameters:
+             del parameters['contents']
+
+         # Remove messages key in parameters (OpenAI message)
+         if 'messages' in parameters:
+             del parameters['messages']
+
+         if 'generation_config' in parameters:
+             generation_config = parameters['generation_config']
+             # If generation_config is already a dict, use it directly
+             if isinstance(generation_config, dict):
+                 config_dict = generation_config
+             else:
+                 # Convert GenerationConfig to a dict via to_dict if available, else fall back to __dict__
+                 config_dict = getattr(generation_config, 'to_dict', lambda: generation_config.__dict__)()
+             parameters.update(config_dict)
+             del parameters['generation_config']
+
+         return parameters
+
+     def _extract_token_usage(self, result):
+         """Extract token usage from result"""
+         # Handle coroutines
+         if asyncio.iscoroutine(result):
+             # Run the coroutine on the current event loop
+             loop = asyncio.get_event_loop()
+             result = loop.run_until_complete(result)
+
+         # Handle standard OpenAI/Anthropic format
+         if hasattr(result, "usage"):
+             usage = result.usage
+             return {
+                 "prompt_tokens": getattr(usage, "prompt_tokens", 0),
+                 "completion_tokens": getattr(usage, "completion_tokens", 0),
+                 "total_tokens": getattr(usage, "total_tokens", 0)
+             }
+
+         # Handle Google GenerativeAI format with usage_metadata
+         if hasattr(result, "usage_metadata"):
+             metadata = result.usage_metadata
+             return {
+                 "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
+                 "completion_tokens": getattr(metadata, "candidates_token_count", 0),
+                 "total_tokens": getattr(metadata, "total_token_count", 0)
+             }
+
+         # Handle Vertex AI format
+         if hasattr(result, "text"):
+             # For LangChain ChatVertexAI
+             total_tokens = getattr(result, "token_count", 0)
+             if not total_tokens and hasattr(result, "_raw_response"):
+                 # Try to get from the raw response
+                 total_tokens = getattr(result._raw_response, "token_count", 0)
+             return {
+                 # TODO: This implementation is incorrect. Vertex AI does provide this breakdown
+                 "prompt_tokens": 0,  # Vertex AI doesn't provide this breakdown
+                 "completion_tokens": total_tokens,
+                 "total_tokens": total_tokens
+             }
+
+         return {  # TODO: Returning 0 when usage is not recorded is not correct. This needs to be fixed; discuss before changing
+             "prompt_tokens": 0,
+             "completion_tokens": 0,
+             "total_tokens": 0
+         }
+
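The branches above normalize each SDK's usage object into one shape. A sketch of what that looks like for a stubbed OpenAI-style response (field names follow the OpenAI SDK; the counts are made up):

    from types import SimpleNamespace

    fake_result = SimpleNamespace(usage=SimpleNamespace(
        prompt_tokens=12, completion_tokens=34, total_tokens=46))

    # tracer._extract_token_usage(fake_result) would return:
    # {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}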
+     def _extract_input_data(self, args, kwargs, result):
+         return {
+             'args': args,
+             'kwargs': kwargs
+         }
+
+     def _calculate_cost(self, token_usage, model_name):
+         """Calculate cost based on token usage and model"""
+         # TODO: Falling back to a default cost is faulty logic and should be fixed
+         if not isinstance(token_usage, dict):
+             token_usage = {
+                 "prompt_tokens": 0,
+                 "completion_tokens": 0,
+                 "total_tokens": token_usage if isinstance(token_usage, (int, float)) else 0
+             }
+
+         # TODO: This is a temporary fix and needs a proper solution
+         # Get model costs, defaulting to zero-cost entries if the model is unknown
+         model_cost = self.model_costs.get(model_name, {
+             "input_cost_per_token": 0.0,
+             "output_cost_per_token": 0.0
+         })
+
+         input_cost = token_usage.get("prompt_tokens", 0) * model_cost.get("input_cost_per_token", 0.0)
+         output_cost = token_usage.get("completion_tokens", 0) * model_cost.get("output_cost_per_token", 0.0)
+         total_cost = input_cost + output_cost
+
+         # TODO: Return the value as is; no need to round
+         return {
+             "input_cost": round(input_cost, 10),
+             "output_cost": round(output_cost, 10),
+             "total_cost": round(total_cost, 10)
+         }
+
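The cost math is a straight per-token multiply against the model_costs table. With illustrative rates (not real prices), 1,000 prompt tokens and 500 completion tokens work out as:

    model_cost = {"input_cost_per_token": 0.000001, "output_cost_per_token": 0.000002}
    input_cost = 1000 * model_cost["input_cost_per_token"]   # 0.001
    output_cost = 500 * model_cost["output_cost_per_token"]  # 0.001
    total_cost = input_cost + output_cost                    # 0.002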
+     def create_llm_component(self, component_id, hash_id, name, llm_type, version, memory_used, start_time, end_time, input_data, output_data, cost={}, usage={}, error=None, parameters={}):
+         # Update total metrics
+         self.total_tokens += usage.get("total_tokens", 0)
+         self.total_cost += cost.get("total_cost", 0)
+
+         component = {
+             "id": component_id,
+             "hash_id": hash_id,
+             "source_hash_id": None,
+             "type": "llm",
+             "name": name,
+             "start_time": start_time.isoformat(),
+             "end_time": end_time.isoformat(),
+             "error": error,
+             "parent_id": self.current_agent_id.get(),
+             "info": {
+                 "model": llm_type,
+                 "version": version,
+                 "memory_used": memory_used,
+                 "cost": cost,
+                 "tokens": usage,
+                 **parameters
+             },
+             "data": {
+                 "input": input_data['args'] if hasattr(input_data, 'args') else input_data,
+                 "output": output_data.output_response if output_data else None,
+                 "memory_used": memory_used
+             },
+             "network_calls": self.component_network_calls.get(component_id, []),
+             "interactions": self.component_user_interaction.get(component_id, [])
+         }
+
+         if self.gt:
+             component["data"]["gt"] = self.gt
+
+         return component
+
+     def start_component(self, component_id):
+         """Start tracking network calls for a component"""
+         self.component_network_calls[component_id] = []
+         self.current_component_id = component_id
+
+     def end_component(self, component_id):
+         """Stop tracking network calls for a component"""
+         self.current_component_id = None
+
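Taken together, a successful call yields a component dict roughly like the following (all values illustrative; the keys mirror create_llm_component above):

    component = {
        "id": "3f2b6c1e-9a7d-4c1f-8b2a-0e5d6f7a8b9c",   # uuid4
        "hash_id": "a1b2c3d4",                           # generate_unique_hash_simple(func)
        "source_hash_id": None,
        "type": "llm",
        "name": "generate_content",
        "start_time": "2025-01-01T12:00:00+00:00",
        "end_time": "2025-01-01T12:00:01+00:00",
        "error": None,
        "parent_id": None,                               # or the enclosing agent's id
        "info": {"model": "gemini-1.5-flash", "version": "1.0.0", "memory_used": 1048576,
                 "cost": {"input_cost": 0.0, "output_cost": 0.0, "total_cost": 0.0},
                 "tokens": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}},
        "data": {"input": {"args": (), "kwargs": {}},
                 "output": [{"content": "Hello!", "role": "assistant"}],
                 "memory_used": 1048576},
        "network_calls": [],
        "interactions": []
    }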
+     async def trace_llm_call(self, original_func, *args, **kwargs):
+         """Trace an LLM API call"""
+         if not self.is_active:
+             return await original_func(*args, **kwargs)
+
+         start_time = datetime.now().astimezone()
+         start_memory = psutil.Process().memory_info().rss
+         component_id = str(uuid.uuid4())
+         hash_id = generate_unique_hash_simple(original_func)
+
+         # Start tracking network calls for this component
+         self.start_component(component_id)
+
+         try:
+             # Execute the LLM call
+             result = await original_func(*args, **kwargs)
+
+             # Calculate resource usage
+             end_time = datetime.now().astimezone()
+             end_memory = psutil.Process().memory_info().rss
+             memory_used = max(0, end_memory - start_memory)
+
+             # Extract token usage and calculate cost
+             token_usage = self._extract_token_usage(result)
+             model_name = self._extract_model_name(args, kwargs, result)
+             cost = self._calculate_cost(token_usage, model_name)
+             parameters = self._extract_parameters(kwargs)
+
+             # End tracking network calls for this component
+             self.end_component(component_id)
+
+             name = self.current_llm_call_name.get()
+             if name is None:
+                 name = original_func.__name__
+
+             # Create input data with ground truth
+             input_data = self._extract_input_data(args, kwargs, result)
+
+             # Create LLM component
+             llm_component = self.create_llm_component(
+                 component_id=component_id,
+                 hash_id=hash_id,
+                 name=name,
+                 llm_type=model_name,
+                 version="1.0.0",
+                 memory_used=memory_used,
+                 start_time=start_time,
+                 end_time=end_time,
+                 input_data=input_data,
+                 output_data=extract_llm_output(result),
+                 cost=cost,
+                 usage=token_usage,
+                 parameters=parameters
+             )
+
+             # Stash the component; the trace_llm decorator decides where to attach it
+             self.llm_data = llm_component
+             return result
+
+         except Exception as e:
+             error_component = {
+                 "code": 500,
+                 "type": type(e).__name__,
+                 "message": str(e),
+                 "details": {}
+             }
+
+             # End tracking network calls for this component
+             self.end_component(component_id)
+
+             end_time = datetime.now().astimezone()
+
+             name = self.current_llm_call_name.get()
+             if name is None:
+                 name = original_func.__name__
+
+             llm_component = self.create_llm_component(
+                 component_id=component_id,
+                 hash_id=hash_id,
+                 name=name,
+                 llm_type="unknown",
+                 version="1.0.0",
+                 memory_used=0,
+                 start_time=start_time,
+                 end_time=end_time,
+                 input_data=self._extract_input_data(args, kwargs, None),
+                 output_data=None,
+                 error=error_component
+             )
+
+             self.add_component(llm_component)
+             raise
+
+     def trace_llm_call_sync(self, original_func, *args, **kwargs):
+         """Sync version of trace_llm_call"""
+         if not self.is_active:
+             if asyncio.iscoroutinefunction(original_func):
+                 return asyncio.run(original_func(*args, **kwargs))
+             return original_func(*args, **kwargs)
+
+         start_time = datetime.now().astimezone()
+         component_id = str(uuid.uuid4())
+         hash_id = generate_unique_hash_simple(original_func)
+
+         # Start tracking network calls for this component
+         self.start_component(component_id)
+
+         start_memory = psutil.Process().memory_info().rss
+
+         try:
+             # Execute the function
+             if asyncio.iscoroutinefunction(original_func):
+                 result = asyncio.run(original_func(*args, **kwargs))
+             else:
+                 result = original_func(*args, **kwargs)
+
+             # Calculate resource usage once the call completes
+             end_time = datetime.now().astimezone()
+             end_memory = psutil.Process().memory_info().rss
+             memory_used = max(0, end_memory - start_memory)
+
+             # Extract token usage and calculate cost
+             token_usage = self._extract_token_usage(result)
+             model_name = self._extract_model_name(args, kwargs, result)
+             cost = self._calculate_cost(token_usage, model_name)
+             parameters = self._extract_parameters(kwargs)
+
+             # End tracking network calls for this component
+             self.end_component(component_id)
+
+             name = self.current_llm_call_name.get()
+             if name is None:
+                 name = original_func.__name__
+
+             # Create input data with ground truth
+             input_data = self._extract_input_data(args, kwargs, result)
+
+             # Create LLM component
+             llm_component = self.create_llm_component(
+                 component_id=component_id,
+                 hash_id=hash_id,
+                 name=name,
+                 llm_type=model_name,
+                 version="1.0.0",
+                 memory_used=memory_used,
+                 start_time=start_time,
+                 end_time=end_time,
+                 input_data=input_data,
+                 output_data=extract_llm_output(result),
+                 cost=cost,
+                 usage=token_usage,
+                 parameters=parameters
+             )
+
+             self.add_component(llm_component)
+             return result
+
+         except Exception as e:
+             error_component = {
+                 "code": 500,
+                 "type": type(e).__name__,
+                 "message": str(e),
+                 "details": {}
+             }
+
+             # End tracking network calls for this component
+             self.end_component(component_id)
+
+             end_time = datetime.now().astimezone()
+
+             name = self.current_llm_call_name.get()
+             if name is None:
+                 name = original_func.__name__
+
+             end_memory = psutil.Process().memory_info().rss
+             memory_used = max(0, end_memory - start_memory)
+
+             llm_component = self.create_llm_component(
+                 component_id=component_id,
+                 hash_id=hash_id,
+                 name=name,
+                 llm_type="unknown",
+                 version="1.0.0",
+                 memory_used=memory_used,
+                 start_time=start_time,
+                 end_time=end_time,
+                 input_data=self._extract_input_data(args, kwargs, None),
+                 output_data=None,
+                 error=error_component
+             )
+
+             self.add_component(llm_component)
+             raise
+
+     def trace_llm(self, name: str = None):
+         def decorator(func):
+             @self.file_tracker.trace_decorator
+             @functools.wraps(func)
+             async def async_wrapper(*args, **kwargs):
+                 self.gt = kwargs.get('gt', None) if kwargs else None
+                 if not self.is_active:
+                     return await func(*args, **kwargs)
+
+                 hash_id = generate_unique_hash_simple(func)
+                 component_id = str(uuid.uuid4())
+                 parent_agent_id = self.current_agent_id.get()
+                 self.start_component(component_id)
+
+                 start_time = datetime.now()
+                 error_info = None
+                 result = None
+
+                 try:
+                     result = await func(*args, **kwargs)
+                     return result
+                 except Exception as e:
+                     error_info = {
+                         "error": {
+                             "type": type(e).__name__,
+                             "message": str(e),
+                             "traceback": traceback.format_exc(),
+                             "timestamp": datetime.now().isoformat()
+                         }
+                     }
+                     raise
+                 finally:
+                     llm_component = self.llm_data
+
+                     if error_info:
+                         llm_component["error"] = error_info["error"]
+
+                     if parent_agent_id:
+                         children = self.agent_children.get()
+                         children.append(llm_component)
+                         self.agent_children.set(children)
+                     else:
+                         self.add_component(llm_component)
+
+                     self.end_component(component_id)
+
+             @self.file_tracker.trace_decorator
+             @functools.wraps(func)
+             def sync_wrapper(*args, **kwargs):
+                 self.gt = kwargs.get('gt', None) if kwargs else None
+                 if not self.is_active:
+                     return func(*args, **kwargs)
+
+                 hash_id = generate_unique_hash_simple(func)
+
+                 component_id = str(uuid.uuid4())
+                 parent_agent_id = self.current_agent_id.get()
+                 self.start_component(component_id)
+
+                 start_time = datetime.now()
+                 error_info = None
+                 result = None
+
+                 try:
+                     result = func(*args, **kwargs)
+                     return result
+                 except Exception as e:
+                     error_info = {
+                         "error": {
+                             "type": type(e).__name__,
+                             "message": str(e),
+                             "traceback": traceback.format_exc(),
+                             "timestamp": datetime.now().isoformat()
+                         }
+                     }
+                     raise
+                 finally:
+                     llm_component = self.llm_data
+
+                     if error_info:
+                         llm_component["error"] = error_info["error"]
+
+                     if parent_agent_id:
+                         children = self.agent_children.get()
+                         children.append(llm_component)
+                         self.agent_children.set(children)
+                     else:
+                         self.add_component(llm_component)
+
+                     self.end_component(component_id)
+
+             return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
+         return decorator
+
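A hedged usage sketch of the decorator above. In the released package this mixin is composed into the agentic tracer, so `trace_llm` would normally be reached through a tracer instance; the `tracer` and `client` objects below are assumptions for illustration, not the package's documented API:

    # Assumes `tracer` is an initialized tracer object exposing trace_llm,
    # and `client` is an openai.OpenAI() instance patched by instrument_llm_calls().
    @tracer.trace_llm(name="summarize")
    def summarize(prompt):
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
        )
        # The patched create() runs through trace_llm_call_sync, which stores the
        # component in self.llm_data; the decorator then attaches it to the trace.
        return response.choices[0].message.content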
+     def unpatch_llm_calls(self):
+         # Remove all patches
+         for obj, method_name, original_method in self.patches:
+             try:
+                 setattr(obj, method_name, original_method)
+             except Exception as e:
+                 print(f"Error unpatching {method_name}: {str(e)}")
+         self.patches = []
+
+     def _sanitize_api_keys(self, data):
+         """Remove sensitive information from data"""
+         if isinstance(data, dict):
+             return {k: self._sanitize_api_keys(v) for k, v in data.items()
+                     if not any(sensitive in k.lower() for sensitive in ['key', 'token', 'secret', 'password'])}
+         elif isinstance(data, list):
+             return [self._sanitize_api_keys(item) for item in data]
+         elif isinstance(data, tuple):
+             return tuple(self._sanitize_api_keys(item) for item in data)
+         return data
+
+     def _sanitize_input(self, args, kwargs):
+         """Convert input arguments to text format.
+
+         Args:
+             args: Input arguments that may contain nested dictionaries
+             kwargs: Accepted for interface symmetry; currently unused
+
+         Returns:
+             str: Text representation of the input arguments
+         """
+         if isinstance(args, dict):
+             return str({k: self._sanitize_input(v, {}) for k, v in args.items()})
+         elif isinstance(args, (list, tuple)):
+             return str([self._sanitize_input(item, {}) for item in args])
+         return str(args)
+
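The filter in `_sanitize_api_keys` drops any dict entry whose key contains 'key', 'token', 'secret', or 'password' (case-insensitive), recursing into dicts, lists, and tuples. A quick illustration; note that the substring match is broad, so a key like 'prompt_tokens' would also be dropped:

    payload = {"model": "gpt-4", "api_key": "sk-REDACTED",
               "nested": {"auth_token": "t0k3n", "temperature": 0.2}}
    # _sanitize_api_keys(payload) ->
    # {"model": "gpt-4", "nested": {"temperature": 0.2}}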
+ # Note: this module-level definition shadows the extract_llm_output
+ # imported from .utils.llm_utils at the top of the file.
+ def extract_llm_output(result):
+     """Extract output from LLM response"""
+     class OutputResponse:
+         def __init__(self, output_response):
+             self.output_response = output_response
+
+     # Handle coroutines
+     if asyncio.iscoroutine(result):
+         # For sync context, run the coroutine
+         if not asyncio.get_event_loop().is_running():
+             result = asyncio.run(result)
+         else:
+             # We're in an async context, but this function was called synchronously;
+             # return a placeholder and let the caller handle the coroutine
+             return OutputResponse("Coroutine result pending")
+
+     # Handle Google GenerativeAI format
+     if hasattr(result, "result"):
+         candidates = getattr(result.result, "candidates", [])
+         output = []
+         for candidate in candidates:
+             content = getattr(candidate, "content", None)
+             if content and hasattr(content, "parts"):
+                 for part in content.parts:
+                     if hasattr(part, "text"):
+                         output.append({
+                             "content": part.text,
+                             "role": getattr(content, "role", "assistant"),
+                             "finish_reason": getattr(candidate, "finish_reason", None)
+                         })
+         return OutputResponse(output)
+
+     # Handle Vertex AI format
+     if hasattr(result, "text"):
+         return OutputResponse([{
+             "content": result.text,
+             "role": "assistant"
+         }])
+
+     # Handle OpenAI format
+     if hasattr(result, "choices"):
+         return OutputResponse([{
+             "content": choice.message.content,
+             "role": choice.message.role
+         } for choice in result.choices])
+
+     # Handle Anthropic format
+     if hasattr(result, "completion"):
+         return OutputResponse([{
+             "content": result.completion,
+             "role": "assistant"
+         }])
+
+     # Default case
+     return OutputResponse(str(result))