ragaai-catalyst 2.0.7.2__py3-none-any.whl → 2.0.7.2b0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (29)
  1. ragaai_catalyst/evaluation.py +107 -153
  2. ragaai_catalyst/tracers/agentic_tracing/Untitled-1.json +660 -0
  3. ragaai_catalyst/tracers/agentic_tracing/__init__.py +3 -0
  4. ragaai_catalyst/tracers/agentic_tracing/agent_tracer.py +311 -0
  5. ragaai_catalyst/tracers/agentic_tracing/agentic_tracing.py +212 -0
  6. ragaai_catalyst/tracers/agentic_tracing/base.py +270 -0
  7. ragaai_catalyst/tracers/agentic_tracing/data_structure.py +239 -0
  8. ragaai_catalyst/tracers/agentic_tracing/llm_tracer.py +906 -0
  9. ragaai_catalyst/tracers/agentic_tracing/network_tracer.py +286 -0
  10. ragaai_catalyst/tracers/agentic_tracing/sample.py +197 -0
  11. ragaai_catalyst/tracers/agentic_tracing/tool_tracer.py +235 -0
  12. ragaai_catalyst/tracers/agentic_tracing/unique_decorator.py +221 -0
  13. ragaai_catalyst/tracers/agentic_tracing/unique_decorator_test.py +172 -0
  14. ragaai_catalyst/tracers/agentic_tracing/user_interaction_tracer.py +67 -0
  15. ragaai_catalyst/tracers/agentic_tracing/utils/__init__.py +3 -0
  16. ragaai_catalyst/tracers/agentic_tracing/utils/api_utils.py +18 -0
  17. ragaai_catalyst/tracers/agentic_tracing/utils/data_classes.py +61 -0
  18. ragaai_catalyst/tracers/agentic_tracing/utils/generic.py +32 -0
  19. ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +181 -0
  20. ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +5946 -0
  21. ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py +74 -0
  22. ragaai_catalyst/tracers/tracer.py +26 -4
  23. ragaai_catalyst/tracers/upload_traces.py +127 -0
  24. ragaai_catalyst-2.0.7.2b0.dist-info/METADATA +39 -0
  25. ragaai_catalyst-2.0.7.2b0.dist-info/RECORD +50 -0
  26. ragaai_catalyst-2.0.7.2.dist-info/METADATA +0 -386
  27. ragaai_catalyst-2.0.7.2.dist-info/RECORD +0 -29
  28. {ragaai_catalyst-2.0.7.2.dist-info → ragaai_catalyst-2.0.7.2b0.dist-info}/WHEEL +0 -0
  29. {ragaai_catalyst-2.0.7.2.dist-info → ragaai_catalyst-2.0.7.2b0.dist-info}/top_level.txt +0 -0
ragaai_catalyst/tracers/agentic_tracing/llm_tracer.py (new file)
@@ -0,0 +1,906 @@
+ from typing import Optional, Any, Dict, List
+ import asyncio
+ import psutil
+ import json
+ import wrapt
+ import functools
+ from datetime import datetime
+ import uuid
+ import os
+ import contextvars
+ import sys
+ import gc
+
+ from .unique_decorator import mydecorator
+ from .utils.trace_utils import calculate_cost, load_model_costs
+ from .utils.llm_utils import extract_llm_output
+
+
+ class LLMTracerMixin:
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.patches = []
+         try:
+             self.model_costs = load_model_costs()
+         except Exception as e:
+             # If model costs can't be loaded, use default costs
+             self.model_costs = {
+                 "default": {
+                     "input_cost_per_token": 0.00002,
+                     "output_cost_per_token": 0.00002
+                 }
+             }
+         self.current_llm_call_name = contextvars.ContextVar("llm_call_name", default=None)
+         self.component_network_calls = {}
+         self.current_component_id = None
+         self.total_tokens = 0
+         self.total_cost = 0.0
+         # Apply decorator to trace_llm_call method
+         self.trace_llm_call = mydecorator(self.trace_llm_call)
+
+     def instrument_llm_calls(self):
+         # Handle modules that are already imported
+         import sys
+
+         if "vertexai" in sys.modules:
+             self.patch_vertex_ai_methods(sys.modules["vertexai"])
+         if "vertexai.generative_models" in sys.modules:
+             self.patch_vertex_ai_methods(sys.modules["vertexai.generative_models"])
+
+         if "openai" in sys.modules:
+             self.patch_openai_methods(sys.modules["openai"])
+         if "litellm" in sys.modules:
+             self.patch_litellm_methods(sys.modules["litellm"])
+         if "anthropic" in sys.modules:
+             self.patch_anthropic_methods(sys.modules["anthropic"])
+         if "google.generativeai" in sys.modules:
+             self.patch_google_genai_methods(sys.modules["google.generativeai"])
+         if "langchain_google_vertexai" in sys.modules:
+             self.patch_langchain_google_methods(sys.modules["langchain_google_vertexai"])
+         if "langchain_google_genai" in sys.modules:
+             self.patch_langchain_google_methods(sys.modules["langchain_google_genai"])
+
+         # Register hooks for future imports
+         wrapt.register_post_import_hook(self.patch_vertex_ai_methods, "vertexai")
+         wrapt.register_post_import_hook(self.patch_vertex_ai_methods, "vertexai.generative_models")
+         wrapt.register_post_import_hook(self.patch_openai_methods, "openai")
+         wrapt.register_post_import_hook(self.patch_litellm_methods, "litellm")
+         wrapt.register_post_import_hook(self.patch_anthropic_methods, "anthropic")
+         wrapt.register_post_import_hook(self.patch_google_genai_methods, "google.generativeai")
+
+         # Add hooks for LangChain integrations
+         wrapt.register_post_import_hook(self.patch_langchain_google_methods, "langchain_google_vertexai")
+         wrapt.register_post_import_hook(self.patch_langchain_google_methods, "langchain_google_genai")
+
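The instrumentation above leans on wrapt's post-import hook mechanism: a registered callback fires when the named module is first imported (or right away if it is already loaded), while the explicit sys.modules checks cover providers imported before instrumentation starts. A minimal standalone sketch (the callback name is illustrative, not part of this package):

    import wrapt

    def on_import(module):
        # Called once with the module object, at import time
        # (or immediately if the module is already loaded)
        print(f"patching {module.__name__}")

    wrapt.register_post_import_hook(on_import, "json")
    import json  # triggers on_import(json)

Either path ensures each provider module passes through exactly one patch_* function.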
+     def patch_openai_methods(self, module):
+         try:
+             if hasattr(module, "OpenAI"):
+                 client_class = getattr(module, "OpenAI")
+                 self.wrap_openai_client_methods(client_class)
+             if hasattr(module, "AsyncOpenAI"):
+                 async_client_class = getattr(module, "AsyncOpenAI")
+                 self.wrap_openai_client_methods(async_client_class)
+         except Exception as e:
+             # Log the error but continue execution
+             print(f"Warning: Failed to patch OpenAI methods: {str(e)}")
+
+     def patch_anthropic_methods(self, module):
+         if hasattr(module, "Anthropic"):
+             client_class = getattr(module, "Anthropic")
+             self.wrap_anthropic_client_methods(client_class)
+
+     def patch_google_genai_methods(self, module):
+         # Patch direct Google GenerativeAI usage
+         if hasattr(module, "GenerativeModel"):
+             model_class = getattr(module, "GenerativeModel")
+             self.wrap_genai_model_methods(model_class)
+
+         # Patch LangChain integration
+         if hasattr(module, "ChatGoogleGenerativeAI"):
+             chat_class = getattr(module, "ChatGoogleGenerativeAI")
+             # Wrap invoke method to capture messages
+             original_invoke = chat_class.invoke
+
+             def patched_invoke(self, messages, *args, **kwargs):
+                 # Store messages in the instance for later use
+                 self._last_messages = messages
+                 return original_invoke(self, messages, *args, **kwargs)
+
+             chat_class.invoke = patched_invoke
+
+             # LangChain v0.2+ uses invoke/ainvoke
+             self.wrap_method(chat_class, "_generate")
+             if hasattr(chat_class, "_agenerate"):
+                 self.wrap_method(chat_class, "_agenerate")
+             # Fallback for completion methods
+             if hasattr(chat_class, "complete"):
+                 self.wrap_method(chat_class, "complete")
+             if hasattr(chat_class, "acomplete"):
+                 self.wrap_method(chat_class, "acomplete")
+
+     def patch_vertex_ai_methods(self, module):
+         # Patch the GenerativeModel class
+         if hasattr(module, "generative_models"):
+             gen_models = getattr(module, "generative_models")
+             if hasattr(gen_models, "GenerativeModel"):
+                 model_class = getattr(gen_models, "GenerativeModel")
+                 self.wrap_vertex_model_methods(model_class)
+
+         # Also patch the class directly if available
+         if hasattr(module, "GenerativeModel"):
+             model_class = getattr(module, "GenerativeModel")
+             self.wrap_vertex_model_methods(model_class)
+
+     def wrap_vertex_model_methods(self, model_class):
+         # Patch both sync and async methods
+         self.wrap_method(model_class, "generate_content")
+         if hasattr(model_class, "generate_content_async"):
+             self.wrap_method(model_class, "generate_content_async")
+
+     def patch_litellm_methods(self, module):
+         self.wrap_method(module, "completion")
+         self.wrap_method(module, "acompletion")
+
+     def patch_langchain_google_methods(self, module):
+         """Patch LangChain's Google integration methods"""
+         if hasattr(module, "ChatVertexAI"):
+             chat_class = getattr(module, "ChatVertexAI")
+             # LangChain v0.2+ uses invoke/ainvoke
+             self.wrap_method(chat_class, "_generate")
+             if hasattr(chat_class, "_agenerate"):
+                 self.wrap_method(chat_class, "_agenerate")
+             # Fallback for completion methods
+             if hasattr(chat_class, "complete"):
+                 self.wrap_method(chat_class, "complete")
+             if hasattr(chat_class, "acomplete"):
+                 self.wrap_method(chat_class, "acomplete")
+
+         if hasattr(module, "ChatGoogleGenerativeAI"):
+             chat_class = getattr(module, "ChatGoogleGenerativeAI")
+             # LangChain v0.2+ uses invoke/ainvoke
+             self.wrap_method(chat_class, "_generate")
+             if hasattr(chat_class, "_agenerate"):
+                 self.wrap_method(chat_class, "_agenerate")
+             # Fallback for completion methods
+             if hasattr(chat_class, "complete"):
+                 self.wrap_method(chat_class, "complete")
+             if hasattr(chat_class, "acomplete"):
+                 self.wrap_method(chat_class, "acomplete")
+
+     def wrap_openai_client_methods(self, client_class):
+         original_init = client_class.__init__
+
+         @functools.wraps(original_init)
+         def patched_init(client_self, *args, **kwargs):
+             original_init(client_self, *args, **kwargs)
+             self.wrap_method(client_self.chat.completions, "create")
+             if hasattr(client_self.chat.completions, "acreate"):
+                 self.wrap_method(client_self.chat.completions, "acreate")
+
+         setattr(client_class, "__init__", patched_init)
+
+     def wrap_anthropic_client_methods(self, client_class):
+         original_init = client_class.__init__
+
+         @functools.wraps(original_init)
+         def patched_init(client_self, *args, **kwargs):
+             original_init(client_self, *args, **kwargs)
+             self.wrap_method(client_self.messages, "create")
+             if hasattr(client_self.messages, "acreate"):
+                 self.wrap_method(client_self.messages, "acreate")
+
+         setattr(client_class, "__init__", patched_init)
+
+     def wrap_genai_model_methods(self, model_class):
+         original_init = model_class.__init__
+
+         @functools.wraps(original_init)
+         def patched_init(model_self, *args, **kwargs):
+             original_init(model_self, *args, **kwargs)
+             self.wrap_method(model_self, "generate_content")
+             if hasattr(model_self, "generate_content_async"):
+                 self.wrap_method(model_self, "generate_content_async")
+
+         setattr(model_class, "__init__", patched_init)
+
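The three wrap_*_client/model_methods helpers all use the same trick: replace the class's __init__ so that every newly constructed instance gets its bound methods wrapped. A condensed, self-contained sketch of the pattern (Client and the printed tracing are hypothetical stand-ins):

    import functools

    class Client:
        def create(self):
            return "real work"

    def wrap_method(obj, name):
        original = getattr(obj, name)
        def traced(*args, **kwargs):
            print(f"tracing {name}")
            return original(*args, **kwargs)
        setattr(obj, name, traced)

    original_init = Client.__init__

    @functools.wraps(original_init)
    def patched_init(client_self, *args, **kwargs):
        original_init(client_self, *args, **kwargs)
        wrap_method(client_self, "create")  # wrap per instance, at construction

    Client.__init__ = patched_init
    print(Client().create())  # prints "tracing create", then "real work"

One consequence of this design: client instances constructed before instrumentation was applied keep their unwrapped methods.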
+     def wrap_method(self, obj, method_name):
+         """
+         Wrap a method with tracing functionality.
+         Works for both class methods and instance methods.
+         """
+         # If obj is a class, we need to patch both the class and any existing instances
+         if isinstance(obj, type):
+             # Store the original class method
+             original_method = getattr(obj, method_name)
+
+             @wrapt.decorator
+             def wrapper(wrapped, instance, args, kwargs):
+                 if asyncio.iscoroutinefunction(wrapped):
+                     return self.trace_llm_call(wrapped, *args, **kwargs)
+                 return self.trace_llm_call_sync(wrapped, *args, **kwargs)
+
+             # Wrap the class method
+             wrapped_method = wrapper(original_method)
+             setattr(obj, method_name, wrapped_method)
+             self.patches.append((obj, method_name, original_method))
+
+         else:
+             # For instance methods
+             original_method = getattr(obj, method_name)
+
+             @wrapt.decorator
+             def wrapper(wrapped, instance, args, kwargs):
+                 if asyncio.iscoroutinefunction(wrapped):
+                     return self.trace_llm_call(wrapped, *args, **kwargs)
+                 return self.trace_llm_call_sync(wrapped, *args, **kwargs)
+
+             wrapped_method = wrapper(original_method)
+             setattr(obj, method_name, wrapped_method)
+             self.patches.append((obj, method_name, original_method))
+
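wrap_method builds its wrapper with wrapt.decorator, which preserves the wrapped callable's signature and hands the bound instance in separately. A minimal standalone sketch of that mechanism (Model is a hypothetical stand-in):

    import wrapt

    @wrapt.decorator
    def traced(wrapped, instance, args, kwargs):
        # wrapped is already bound; instance is the object it is bound to
        print(f"calling {wrapped.__name__}")
        return wrapped(*args, **kwargs)

    class Model:
        def generate(self, prompt):
            return prompt.upper()

    Model.generate = traced(Model.generate)
    print(Model().generate("hi"))  # prints "calling generate", then "HI"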
+     def _extract_model_name(self, kwargs):
+         """Extract model name from kwargs or result"""
+         # First try direct model parameter
+         model = kwargs.get("model", "")
+
+         if not model:
+             # Try to get from instance
+             instance = kwargs.get("self", None)
+             if instance:
+                 # Try model_name first (Google format)
+                 if hasattr(instance, "model_name"):
+                     model = instance.model_name
+                 # Try model attribute
+                 elif hasattr(instance, "model"):
+                     model = instance.model
+
+         # Normalize Google model names
+         if model and isinstance(model, str):
+             model = model.lower()
+             if "gemini-1.5-flash" in model:
+                 return "gemini-1.5-flash"
+             if "gemini-1.5-pro" in model:
+                 return "gemini-1.5-pro"
+             if "gemini-pro" in model:
+                 return "gemini-pro"
+
+         return model or "default"
+
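Normalization examples for _extract_model_name (inputs are illustrative):

    # "models/Gemini-1.5-Flash-002" -> "gemini-1.5-flash"
    # "gemini-pro-vision"           -> "gemini-pro"
    # "GPT-4o-mini"                 -> "gpt-4o-mini"  (lower-cased pass-through)
    # ""                            -> "default"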
+     def _extract_parameters(self, kwargs, result=None):
+         """Extract parameters from kwargs or result"""
+         params = {
+             "temperature": kwargs.get("temperature", getattr(result, "temperature", 0.7)),
+             "top_p": kwargs.get("top_p", getattr(result, "top_p", 1.0)),
+             "max_tokens": kwargs.get("max_tokens", getattr(result, "max_tokens", 512))
+         }
+
+         # Add Google AI specific parameters if available
+         if hasattr(kwargs.get("self", None), "generation_config"):
+             gen_config = kwargs["self"].generation_config
+             params.update({
+                 "candidate_count": getattr(gen_config, "candidate_count", 1),
+                 "stop_sequences": getattr(gen_config, "stop_sequences", []),
+                 "top_k": getattr(gen_config, "top_k", 40)
+             })
+
+         return params
+
+     def _extract_token_usage(self, result):
+         """Extract token usage from result"""
+         # Handle coroutines
+         if asyncio.iscoroutine(result):
+             result = asyncio.run(result)
+
+         # Handle standard OpenAI/Anthropic format
+         if hasattr(result, "usage"):
+             usage = result.usage
+             return {
+                 "prompt_tokens": getattr(usage, "prompt_tokens", 0),
+                 "completion_tokens": getattr(usage, "completion_tokens", 0),
+                 "total_tokens": getattr(usage, "total_tokens", 0)
+             }
+
+         # Handle Google GenerativeAI format with usage_metadata
+         if hasattr(result, "usage_metadata"):
+             metadata = result.usage_metadata
+             return {
+                 "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
+                 "completion_tokens": getattr(metadata, "candidates_token_count", 0),
+                 "total_tokens": getattr(metadata, "total_token_count", 0)
+             }
+
+         # Handle Vertex AI format
+         if hasattr(result, "text"):
+             # For LangChain ChatVertexAI
+             total_tokens = getattr(result, "token_count", 0)
+             if not total_tokens and hasattr(result, "_raw_response"):
+                 # Try to get from raw response
+                 total_tokens = getattr(result._raw_response, "token_count", 0)
+             return {
+                 "prompt_tokens": 0,  # Vertex AI doesn't provide this breakdown
+                 "completion_tokens": total_tokens,
+                 "total_tokens": total_tokens
+             }
+
+         return {
+             "prompt_tokens": 0,
+             "completion_tokens": 0,
+             "total_tokens": 0
+         }
+
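Note that _extract_token_usage is defined twice in this class body; the async redefinition near the end of the file is the one bound when the class is created, which is what lets trace_llm_call await it. Whichever definition runs, every provider's usage is normalized to the same three-key dict. A shape sketch with a stand-in OpenAI-style response:

    from types import SimpleNamespace

    usage = SimpleNamespace(prompt_tokens=12, completion_tokens=30, total_tokens=42)
    result = SimpleNamespace(usage=usage)
    # -> {"prompt_tokens": 12, "completion_tokens": 30, "total_tokens": 42}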
+     def _extract_input_data(self, kwargs, result):
+         """Extract input data from kwargs and result"""
+
+         # For Vertex AI GenerationResponse
+         if hasattr(result, 'candidates') and hasattr(result, 'usage_metadata'):
+             # Extract generation config
+             generation_config = kwargs.get('generation_config', {})
+             config_dict = {}
+             if hasattr(generation_config, 'temperature'):
+                 config_dict['temperature'] = generation_config.temperature
+             if hasattr(generation_config, 'top_p'):
+                 config_dict['top_p'] = generation_config.top_p
+             if hasattr(generation_config, 'max_output_tokens'):
+                 config_dict['max_tokens'] = generation_config.max_output_tokens
+             if hasattr(generation_config, 'candidate_count'):
+                 config_dict['n'] = generation_config.candidate_count
+
+             return {
+                 "prompt": kwargs.get('contents', ''),
+                 "model": "gemini-1.5-flash-002",
+                 **config_dict
+             }
+
+         # For standard OpenAI format
+         messages = kwargs.get("messages", [])
+         if messages:
+             return {
+                 "messages": messages,
+                 "model": kwargs.get("model", "unknown"),
+                 "temperature": kwargs.get("temperature", 0.7),
+                 "max_tokens": kwargs.get("max_tokens", None),
+                 "top_p": kwargs.get("top_p", None),
+                 "frequency_penalty": kwargs.get("frequency_penalty", None),
+                 "presence_penalty": kwargs.get("presence_penalty", None)
+             }
+
+         # For text completion format
+         if "prompt" in kwargs:
+             return {
+                 "prompt": kwargs["prompt"],
+                 "model": kwargs.get("model", "unknown"),
+                 "temperature": kwargs.get("temperature", 0.7),
+                 "max_tokens": kwargs.get("max_tokens", None),
+                 "top_p": kwargs.get("top_p", None),
+                 "frequency_penalty": kwargs.get("frequency_penalty", None),
+                 "presence_penalty": kwargs.get("presence_penalty", None)
+             }
+
+         # For any other case, try to extract from kwargs
+         if "contents" in kwargs:
+             return {
+                 "prompt": kwargs["contents"],
+                 "model": kwargs.get("model", "unknown"),
+                 "temperature": kwargs.get("temperature", 0.7),
+                 "max_tokens": kwargs.get("max_tokens", None),
+                 "top_p": kwargs.get("top_p", None)
+             }
+
+         print("No input data found")
+         return {}
+
+     def _calculate_cost(self, token_usage, model_name):
+         """Calculate cost based on token usage and model"""
+         if not isinstance(token_usage, dict):
+             token_usage = {
+                 "prompt_tokens": 0,
+                 "completion_tokens": 0,
+                 "total_tokens": token_usage if isinstance(token_usage, (int, float)) else 0
+             }
+
+         # Get model costs, defaulting to Vertex AI PaLM2 costs if unknown
+         model_cost = self.model_costs.get(model_name, {
+             "input_cost_per_token": 0.0005,  # $0.0005 per 1K input tokens
+             "output_cost_per_token": 0.0005  # $0.0005 per 1K output tokens
+         })
+
+         # Calculate costs per 1K tokens
+         input_cost = (token_usage.get("prompt_tokens", 0) / 1000.0) * model_cost.get("input_cost_per_token", 0.0005)
+         output_cost = (token_usage.get("completion_tokens", 0) / 1000.0) * model_cost.get("output_cost_per_token", 0.0005)
+         total_cost = input_cost + output_cost
+
+         return {
+             "input_cost": round(input_cost, 6),
+             "output_cost": round(output_cost, 6),
+             "total_cost": round(total_cost, 6)
+         }
+
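Worked example with the fallback rates above: a call that consumed 1,000 prompt tokens and 500 completion tokens costs (1000 / 1000) × 0.0005 + (500 / 1000) × 0.0005 = 0.0005 + 0.00025 = $0.00075. Note the tension between the key names and the arithmetic: the keys are called input_cost_per_token/output_cost_per_token, but the division by 1,000 means any genuinely per-token rates loaded from model_costs.json are effectively applied here as per-1K rates.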
+     def create_llm_component(self, **kwargs):
+         """Create an LLM component according to the data structure"""
+         start_time = kwargs["start_time"]
+
+         # Ensure cost and usage are dictionaries
+         cost = kwargs.get("cost", {})
+         if not isinstance(cost, dict):
+             cost = {"total_cost": cost}
+
+         usage = kwargs.get("usage", {})
+         if not isinstance(usage, dict):
+             usage = {"total_tokens": usage}
+
+         component = {
+             "id": kwargs["component_id"],
+             "hash_id": kwargs["hash_id"],
+             "source_hash_id": None,
+             "type": "llm",
+             "name": kwargs["name"],
+             "start_time": start_time.isoformat(),
+             "end_time": kwargs["end_time"].isoformat(),
+             "error": kwargs.get("error"),
+             "parent_id": self.current_agent_id.get(),
+             "info": {
+                 "llm_type": kwargs.get("llm_type", "unknown"),
+                 "version": kwargs.get("version", "1.0.0"),
+                 "memory_used": kwargs.get("memory_used", 0),
+                 "cost": cost,
+                 "tokens": usage
+             },
+             "data": {
+                 "input": kwargs.get("input_data"),
+                 "output": kwargs.get("output_data"),
+                 "memory_used": kwargs.get("memory_used", 0)
+             },
+             "network_calls": self.component_network_calls.get(kwargs["component_id"], []),
+             "interactions": [
+                 {
+                     "id": f"int_{uuid.uuid4()}",
+                     "interaction_type": "input",
+                     "timestamp": start_time.isoformat(),
+                     "content": kwargs.get("input_data")
+                 },
+                 {
+                     "id": f"int_{uuid.uuid4()}",
+                     "interaction_type": "output",
+                     "timestamp": kwargs["end_time"].isoformat(),
+                     "content": kwargs.get("output_data")
+                 }
+             ]
+         }
+         return component
+
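Abridged shape of the dict this method returns (field values illustrative):

    {
        "id": "<uuid>", "hash_id": "<hash>", "source_hash_id": None,
        "type": "llm", "name": "<call name>", "parent_id": "<agent id>",
        "start_time": "2024-01-01T12:00:00+00:00", "end_time": "...", "error": None,
        "info": {"llm_type": "gemini-1.5-pro", "version": "1.0.0",
                 "memory_used": 0, "cost": {...}, "tokens": {...}},
        "data": {"input": {...}, "output": {...}, "memory_used": 0},
        "network_calls": [],
        "interactions": [{"interaction_type": "input", ...},
                         {"interaction_type": "output", ...}],
    }

parent_id reads current_agent_id, a context variable this mixin does not define itself (presumably supplied by the agent tracer mixin), which is how LLM spans end up nested under the active agent.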
+     def start_component(self, component_id):
+         """Start tracking network calls for a component"""
+         self.component_network_calls[component_id] = []
+         self.current_component_id = component_id
+
+     def end_component(self, component_id):
+         """Stop tracking network calls for a component"""
+         self.current_component_id = None
+
+
+     async def trace_llm_call(self, original_func, *args, **kwargs):
+         """Trace an LLM API call"""
+         if not self.is_active:
+             if asyncio.iscoroutinefunction(original_func):
+                 return await original_func(*args, **kwargs)
+             return original_func(*args, **kwargs)
+
+         start_time = datetime.now().astimezone()
+         start_memory = psutil.Process().memory_info().rss
+         component_id = str(uuid.uuid4())
+         hash_id = self.trace_llm_call.hash_id
+
+         # Start tracking network calls for this component
+         self.start_component(component_id)
+
+         try:
+             # Execute the LLM call
+             result = None
+             if asyncio.iscoroutinefunction(original_func):
+                 result = await original_func(*args, **kwargs)
+             else:
+                 result = original_func(*args, **kwargs)
+
+             # If result is a coroutine, await it
+             if asyncio.iscoroutine(result):
+                 result = await result
+
+             # Calculate resource usage
+             end_time = datetime.now().astimezone()
+             end_memory = psutil.Process().memory_info().rss
+             memory_used = max(0, end_memory - start_memory)
+
+             # Extract token usage and calculate cost
+             token_usage = await self._extract_token_usage(result)
+             model_name = self._extract_model_name(kwargs)
+             cost = self._calculate_cost(token_usage, model_name)
+
+             # End tracking network calls for this component
+             self.end_component(component_id)
+
+             # Create LLM component
+             llm_component = self.create_llm_component(
+                 component_id=component_id,
+                 hash_id=hash_id,
+                 name=self.current_llm_call_name.get(),
+                 llm_type=model_name,
+                 version="1.0.0",
+                 memory_used=memory_used,
+                 start_time=start_time,
+                 end_time=end_time,
+                 input_data=self._extract_input_data(kwargs, result),
+                 output_data=extract_llm_output(result),
+                 cost=cost,
+                 usage=token_usage
+             )
+
+             self.add_component(llm_component)
+             return result
+
+         except Exception as e:
+             error_component = {
+                 "code": 500,
+                 "type": type(e).__name__,
+                 "message": str(e),
+                 "details": {}
+             }
+
+             # End tracking network calls for this component
+             self.end_component(component_id)
+
+             end_time = datetime.now().astimezone()
+
+             llm_component = self.create_llm_component(
+                 component_id=component_id,
+                 hash_id=hash_id,
+                 name=self.current_llm_call_name.get(),
+                 llm_type="unknown",
+                 version="1.0.0",
+                 memory_used=0,
+                 start_time=start_time,
+                 end_time=end_time,
+                 input_data=self._extract_input_data(kwargs, None),
+                 output_data=None,
+                 error=error_component
+             )
+
+             self.add_component(llm_component)
+             raise
+
+     def _extract_token_usage_sync(self, result):
+         """Sync version of extract token usage"""
+         # Handle coroutines
+         if asyncio.iscoroutine(result):
+             result = asyncio.run(result)
+
+         # Handle standard OpenAI/Anthropic format
+         if hasattr(result, "usage"):
+             usage = result.usage
+             return {
+                 "prompt_tokens": getattr(usage, "prompt_tokens", 0),
+                 "completion_tokens": getattr(usage, "completion_tokens", 0),
+                 "total_tokens": getattr(usage, "total_tokens", 0)
+             }
+
+         # Handle Google GenerativeAI format with usage_metadata
+         if hasattr(result, "usage_metadata"):
+             metadata = result.usage_metadata
+             return {
+                 "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
+                 "completion_tokens": getattr(metadata, "candidates_token_count", 0),
+                 "total_tokens": getattr(metadata, "total_token_count", 0)
+             }
+
+         # Handle Vertex AI format
+         if hasattr(result, "text"):
+             # For LangChain ChatVertexAI
+             total_tokens = getattr(result, "token_count", 0)
+             if not total_tokens and hasattr(result, "_raw_response"):
+                 # Try to get from raw response
+                 total_tokens = getattr(result._raw_response, "token_count", 0)
+             return {
+                 "prompt_tokens": 0,  # Vertex AI doesn't provide this breakdown
+                 "completion_tokens": total_tokens,
+                 "total_tokens": total_tokens
+             }
+
+         return {
+             "prompt_tokens": 0,
+             "completion_tokens": 0,
+             "total_tokens": 0
+         }
+
+     def trace_llm_call_sync(self, original_func, *args, **kwargs):
+         """Sync version of trace_llm_call"""
+         if not self.is_active:
+             if asyncio.iscoroutinefunction(original_func):
+                 return asyncio.run(original_func(*args, **kwargs))
+             return original_func(*args, **kwargs)
+
+         start_time = datetime.now().astimezone()
+         start_memory = psutil.Process().memory_info().rss
+         component_id = str(uuid.uuid4())
+         hash_id = self.trace_llm_call.hash_id
+
+         # Start tracking network calls for this component
+         self.start_component(component_id)
+
+         try:
+             # Execute the LLM call
+             result = None
+             if asyncio.iscoroutinefunction(original_func):
+                 result = asyncio.run(original_func(*args, **kwargs))
+             else:
+                 result = original_func(*args, **kwargs)
+
+             # If result is a coroutine, run it
+             if asyncio.iscoroutine(result):
+                 result = asyncio.run(result)
+
+             # Calculate resource usage
+             end_time = datetime.now().astimezone()
+             end_memory = psutil.Process().memory_info().rss
+             memory_used = max(0, end_memory - start_memory)
+
+             # Extract token usage and calculate cost
+             token_usage = self._extract_token_usage_sync(result)
+             model_name = self._extract_model_name(kwargs)
+             cost = self._calculate_cost(token_usage, model_name)
+
+             # End tracking network calls for this component
+             self.end_component(component_id)
+
+             # Create LLM component
+             llm_component = self.create_llm_component(
+                 component_id=component_id,
+                 hash_id=hash_id,
+                 name=self.current_llm_call_name.get(),
+                 llm_type=model_name,
+                 version="1.0.0",
+                 memory_used=memory_used,
+                 start_time=start_time,
+                 end_time=end_time,
+                 input_data=self._extract_input_data(kwargs, result),
+                 output_data=extract_llm_output(result),
+                 cost=cost,
+                 usage=token_usage
+             )
+
+             self.add_component(llm_component)
+             return result
+
+         except Exception as e:
+             error_component = {
+                 "code": 500,
+                 "type": type(e).__name__,
+                 "message": str(e),
+                 "details": {}
+             }
+
+             # End tracking network calls for this component
+             self.end_component(component_id)
+
+             end_time = datetime.now().astimezone()
+
+             llm_component = self.create_llm_component(
+                 component_id=component_id,
+                 hash_id=hash_id,
+                 name=self.current_llm_call_name.get(),
+                 llm_type="unknown",
+                 version="1.0.0",
+                 memory_used=0,
+                 start_time=start_time,
+                 end_time=end_time,
+                 input_data=self._extract_input_data(kwargs, None),
+                 output_data=None,
+                 error=error_component
+             )
+
+             self.add_component(llm_component)
+             raise
+
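One caveat worth knowing about the sync path: asyncio.run() cannot be nested inside a running event loop, so if trace_llm_call_sync receives a coroutine while a loop is already running, the asyncio.run(result) calls above raise RuntimeError. A minimal repro:

    import asyncio

    async def main():
        coro = asyncio.sleep(0)
        try:
            asyncio.run(coro)  # RuntimeError: cannot be called from a running event loop
        except RuntimeError as e:
            print(e)
            coro.close()  # suppress the "never awaited" warning

    asyncio.run(main())

The async trace_llm_call above avoids this by awaiting the coroutine instead.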
+     async def _extract_token_usage(self, result):
+         """Extract token usage from result"""
+         # Handle coroutines
+         if asyncio.iscoroutine(result):
+             result = await result
+
+         # Handle standard OpenAI/Anthropic format
+         if hasattr(result, "usage"):
+             usage = result.usage
+             return {
+                 "prompt_tokens": getattr(usage, "prompt_tokens", 0),
+                 "completion_tokens": getattr(usage, "completion_tokens", 0),
+                 "total_tokens": getattr(usage, "total_tokens", 0)
+             }
+
+         # Handle Google GenerativeAI format with usage_metadata
+         if hasattr(result, "usage_metadata"):
+             metadata = result.usage_metadata
+             return {
+                 "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
+                 "completion_tokens": getattr(metadata, "candidates_token_count", 0),
+                 "total_tokens": getattr(metadata, "total_token_count", 0)
+             }
+
+         # Handle Vertex AI format
+         if hasattr(result, "text"):
+             # For LangChain ChatVertexAI
+             total_tokens = getattr(result, "token_count", 0)
+             if not total_tokens and hasattr(result, "_raw_response"):
+                 # Try to get from raw response
+                 total_tokens = getattr(result._raw_response, "token_count", 0)
+             return {
+                 "prompt_tokens": 0,  # Vertex AI doesn't provide this breakdown
+                 "completion_tokens": total_tokens,
+                 "total_tokens": total_tokens
+             }
+
+         return {
+             "prompt_tokens": 0,
+             "completion_tokens": 0,
+             "total_tokens": 0
+         }
+
+     def trace_llm(self, name: str, tool_type: str = "llm", version: str = "1.0.0"):
+         def decorator(func_or_class):
+             if isinstance(func_or_class, type):
+                 for attr_name, attr_value in func_or_class.__dict__.items():
+                     if callable(attr_value) and not attr_name.startswith("__"):
+                         setattr(
+                             func_or_class,
+                             attr_name,
+                             self.trace_llm(f"{name}.{attr_name}", tool_type, version)(attr_value),
+                         )
+                 return func_or_class
+             else:
+                 @functools.wraps(func_or_class)
+                 async def async_wrapper(*args, **kwargs):
+                     token = self.current_llm_call_name.set(name)
+                     try:
+                         return await func_or_class(*args, **kwargs)
+                     finally:
+                         self.current_llm_call_name.reset(token)
+
+                 @functools.wraps(func_or_class)
+                 def sync_wrapper(*args, **kwargs):
+                     token = self.current_llm_call_name.set(name)
+                     try:
+                         return func_or_class(*args, **kwargs)
+                     finally:
+                         self.current_llm_call_name.reset(token)
+
+                 return async_wrapper if asyncio.iscoroutinefunction(func_or_class) else sync_wrapper
+
+         return decorator
+
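A hypothetical usage sketch of trace_llm, assuming a tracer object that mixes in LLMTracerMixin (the tracer construction and client names here are illustrative, not this package's documented API):

    tracer = AgenticTracing()  # hypothetical: any object mixing in LLMTracerMixin

    @tracer.trace_llm(name="summarize")
    def summarize(client, text):
        return client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": f"Summarize: {text}"}],
        )

The decorator itself only sets current_llm_call_name for the duration of the call; the actual capture happens inside the patched provider methods, which read that context variable for the component name.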
+     def unpatch_llm_calls(self):
+         """Remove all patches"""
+         for obj, method_name, original_method in self.patches:
+             if hasattr(obj, method_name):
+                 setattr(obj, method_name, original_method)
+         self.patches.clear()
+
+     def _sanitize_api_keys(self, data):
+         """Remove sensitive information from data"""
+         if isinstance(data, dict):
+             return {k: self._sanitize_api_keys(v) for k, v in data.items()
+                     if not any(sensitive in k.lower() for sensitive in ['key', 'token', 'secret', 'password'])}
+         elif isinstance(data, list):
+             return [self._sanitize_api_keys(item) for item in data]
+         elif isinstance(data, tuple):
+             return tuple(self._sanitize_api_keys(item) for item in data)
+         return data
+
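Behavior sketch for _sanitize_api_keys: matching keys are dropped entirely (not masked), recursively through dicts, lists, and tuples:

    data = {"api_key": "sk-...", "model": "gpt-4", "opts": {"auth_token": "t", "n": 1}}
    # _sanitize_api_keys(data) -> {"model": "gpt-4", "opts": {"n": 1}}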
+     def _create_llm_component(self, component_id, hash_id, name, llm_type, version, memory_used, start_time, end_time, input_data, output_data, usage=None, error=None):
+         cost = None
+         tokens = None
+
+         if usage:
+             tokens = {
+                 "prompt_tokens": usage.get("prompt_tokens", 0),
+                 "completion_tokens": usage.get("completion_tokens", 0),
+                 "total_tokens": usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0)
+             }
+             cost = calculate_cost(usage)
+
+             # Update total metrics
+             self.total_tokens += tokens["total_tokens"]
+             self.total_cost += cost["total"]
+
+         component = {
+             "id": component_id,
+             "hash_id": hash_id,
+             "source_hash_id": None,
+             "type": "llm",
+             "name": name,
+             "start_time": start_time.isoformat(),
+             "end_time": end_time.isoformat(),
+             "error": error,
+             "parent_id": self.current_agent_id.get(),
+             "info": {
+                 "llm_type": llm_type,
+                 "version": version,
+                 "memory_used": memory_used,
+                 "cost": cost,
+                 "tokens": tokens
+             },
+             "data": {
+                 "input": input_data,
+                 "output": output_data.output_response if output_data else None,
+                 "memory_used": memory_used
+             },
+             "network_calls": self.component_network_calls.get(component_id, []),
+             "interactions": [
+                 {
+                     "id": f"int_{uuid.uuid4()}",
+                     "interaction_type": "input",
+                     "timestamp": start_time.isoformat(),
+                     "content": input_data
+                 },
+                 {
+                     "id": f"int_{uuid.uuid4()}",
+                     "interaction_type": "output",
+                     "timestamp": end_time.isoformat(),
+                     "content": output_data.output_response if output_data else None
+                 }
+             ]
+         }
+
+         return component
+
+ def extract_llm_output(result):
+     """Extract output from LLM response"""
+     class OutputResponse:
+         def __init__(self, output_response):
+             self.output_response = output_response
+
+     # Handle coroutines
+     if asyncio.iscoroutine(result):
+         # For sync context, run the coroutine
+         if not asyncio.get_event_loop().is_running():
+             result = asyncio.run(result)
+         else:
+             # We're in an async context, but this function is called synchronously
+             # Return a placeholder and let the caller handle the coroutine
+             return OutputResponse("Coroutine result pending")
+
+     # Handle Google GenerativeAI format
+     if hasattr(result, "result"):
+         candidates = getattr(result.result, "candidates", [])
+         output = []
+         for candidate in candidates:
+             content = getattr(candidate, "content", None)
+             if content and hasattr(content, "parts"):
+                 for part in content.parts:
+                     if hasattr(part, "text"):
+                         output.append({
+                             "content": part.text,
+                             "role": getattr(content, "role", "assistant"),
+                             "finish_reason": getattr(candidate, "finish_reason", None)
+                         })
+         return OutputResponse(output)
+
+     # Handle Vertex AI format
+     if hasattr(result, "text"):
+         return OutputResponse([{
+             "content": result.text,
+             "role": "assistant"
+         }])
+
+     # Handle OpenAI format
+     if hasattr(result, "choices"):
+         return OutputResponse([{
+             "content": choice.message.content,
+             "role": choice.message.role
+         } for choice in result.choices])
+
+     # Handle Anthropic format
+     if hasattr(result, "completion"):
+         return OutputResponse([{
+             "content": result.completion,
+             "role": "assistant"
+         }])
+
+     # Default case
+     return OutputResponse(str(result))
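A quick sanity check of the OpenAI branch with stand-in objects (SimpleNamespace substitutes for the real response types):

    from types import SimpleNamespace

    message = SimpleNamespace(content="Hello!", role="assistant")
    response = SimpleNamespace(choices=[SimpleNamespace(message=message)])

    out = extract_llm_output(response)
    print(out.output_response)  # [{'content': 'Hello!', 'role': 'assistant'}]

Since this definition sits at module level (its signature takes no self), it shadows the extract_llm_output imported from .utils.llm_utils at the top of the file, so the trace methods in this module resolve to this local copy.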