ragaai-catalyst 2.1b0__py3-none-any.whl → 2.1b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragaai_catalyst/__init__.py +1 -0
- ragaai_catalyst/dataset.py +1 -4
- ragaai_catalyst/evaluation.py +4 -5
- ragaai_catalyst/guard_executor.py +97 -0
- ragaai_catalyst/guardrails_manager.py +41 -15
- ragaai_catalyst/internal_api_completion.py +1 -1
- ragaai_catalyst/prompt_manager.py +7 -2
- ragaai_catalyst/ragaai_catalyst.py +1 -1
- ragaai_catalyst/synthetic_data_generation.py +7 -0
- ragaai_catalyst/tracers/__init__.py +1 -1
- ragaai_catalyst/tracers/agentic_tracing/__init__.py +3 -0
- ragaai_catalyst/tracers/agentic_tracing/agent_tracer.py +422 -0
- ragaai_catalyst/tracers/agentic_tracing/agentic_tracing.py +198 -0
- ragaai_catalyst/tracers/agentic_tracing/base.py +376 -0
- ragaai_catalyst/tracers/agentic_tracing/data_structure.py +248 -0
- ragaai_catalyst/tracers/agentic_tracing/examples/FinancialAnalysisSystem.ipynb +536 -0
- ragaai_catalyst/tracers/agentic_tracing/examples/GameActivityEventPlanner.ipynb +134 -0
- ragaai_catalyst/tracers/agentic_tracing/examples/TravelPlanner.ipynb +563 -0
- ragaai_catalyst/tracers/agentic_tracing/file_name_tracker.py +46 -0
- ragaai_catalyst/tracers/agentic_tracing/llm_tracer.py +808 -0
- ragaai_catalyst/tracers/agentic_tracing/network_tracer.py +286 -0
- ragaai_catalyst/tracers/agentic_tracing/sample.py +197 -0
- ragaai_catalyst/tracers/agentic_tracing/tool_tracer.py +247 -0
- ragaai_catalyst/tracers/agentic_tracing/unique_decorator.py +165 -0
- ragaai_catalyst/tracers/agentic_tracing/unique_decorator_test.py +172 -0
- ragaai_catalyst/tracers/agentic_tracing/upload_agentic_traces.py +187 -0
- ragaai_catalyst/tracers/agentic_tracing/upload_code.py +115 -0
- ragaai_catalyst/tracers/agentic_tracing/user_interaction_tracer.py +43 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/__init__.py +3 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/api_utils.py +18 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/data_classes.py +61 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/generic.py +32 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +177 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +7823 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py +74 -0
- ragaai_catalyst/tracers/agentic_tracing/zip_list_of_unique_files.py +184 -0
- ragaai_catalyst/tracers/exporters/raga_exporter.py +1 -7
- ragaai_catalyst/tracers/tracer.py +30 -4
- ragaai_catalyst/tracers/upload_traces.py +127 -0
- ragaai_catalyst-2.1b2.dist-info/METADATA +43 -0
- ragaai_catalyst-2.1b2.dist-info/RECORD +56 -0
- {ragaai_catalyst-2.1b0.dist-info → ragaai_catalyst-2.1b2.dist-info}/WHEEL +1 -1
- ragaai_catalyst-2.1b0.dist-info/METADATA +0 -295
- ragaai_catalyst-2.1b0.dist-info/RECORD +0 -28
- {ragaai_catalyst-2.1b0.dist-info → ragaai_catalyst-2.1b2.dist-info}/top_level.txt +0 -0
ragaai_catalyst/tracers/agentic_tracing/llm_tracer.py (new file)
@@ -0,0 +1,808 @@
from typing import Optional, Any, Dict, List
import asyncio
import psutil
import wrapt
import functools
from datetime import datetime
import uuid
import contextvars
import traceback

from .unique_decorator import generate_unique_hash_simple
from .utils.trace_utils import load_model_costs
from .utils.llm_utils import extract_llm_output
from .file_name_tracker import TrackName


class LLMTracerMixin:
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.file_tracker = TrackName()
        self.patches = []
        try:
            self.model_costs = load_model_costs()
        except Exception as e:
            self.model_costs = {
                # TODO: Default cost handling needs to be improved
                "default": {
                    "input_cost_per_token": 0.0,
                    "output_cost_per_token": 0.0
                }
            }
        self.current_llm_call_name = contextvars.ContextVar("llm_call_name", default=None)
        self.component_network_calls = {}
        self.component_user_interaction = {}
        self.current_component_id = None
        self.total_tokens = 0
        self.total_cost = 0.0
        self.llm_data = {}

    def instrument_llm_calls(self):
        # Handle modules that are already imported
        import sys

        if "vertexai" in sys.modules:
            self.patch_vertex_ai_methods(sys.modules["vertexai"])
        if "vertexai.generative_models" in sys.modules:
            self.patch_vertex_ai_methods(sys.modules["vertexai.generative_models"])

        if "openai" in sys.modules:
            self.patch_openai_methods(sys.modules["openai"])
        if "litellm" in sys.modules:
            self.patch_litellm_methods(sys.modules["litellm"])
        if "anthropic" in sys.modules:
            self.patch_anthropic_methods(sys.modules["anthropic"])
        if "google.generativeai" in sys.modules:
            self.patch_google_genai_methods(sys.modules["google.generativeai"])
        if "langchain_google_vertexai" in sys.modules:
            self.patch_langchain_google_methods(sys.modules["langchain_google_vertexai"])
        if "langchain_google_genai" in sys.modules:
            self.patch_langchain_google_methods(sys.modules["langchain_google_genai"])

        # Register hooks for future imports
        wrapt.register_post_import_hook(self.patch_vertex_ai_methods, "vertexai")
        wrapt.register_post_import_hook(self.patch_vertex_ai_methods, "vertexai.generative_models")
        wrapt.register_post_import_hook(self.patch_openai_methods, "openai")
        wrapt.register_post_import_hook(self.patch_litellm_methods, "litellm")
        wrapt.register_post_import_hook(self.patch_anthropic_methods, "anthropic")
        wrapt.register_post_import_hook(self.patch_google_genai_methods, "google.generativeai")

        # Add hooks for LangChain integrations
        wrapt.register_post_import_hook(self.patch_langchain_google_methods, "langchain_google_vertexai")
        wrapt.register_post_import_hook(self.patch_langchain_google_methods, "langchain_google_genai")

    def patch_openai_methods(self, module):
        try:
            if hasattr(module, "OpenAI"):
                client_class = getattr(module, "OpenAI")
                self.wrap_openai_client_methods(client_class)
            if hasattr(module, "AsyncOpenAI"):
                async_client_class = getattr(module, "AsyncOpenAI")
                self.wrap_openai_client_methods(async_client_class)
        except Exception as e:
            # Log the error but continue execution
            print(f"Warning: Failed to patch OpenAI methods: {str(e)}")

    def patch_anthropic_methods(self, module):
        if hasattr(module, "Anthropic"):
            client_class = getattr(module, "Anthropic")
            self.wrap_anthropic_client_methods(client_class)

    def patch_google_genai_methods(self, module):
        # Patch direct Google GenerativeAI usage
        if hasattr(module, "GenerativeModel"):
            model_class = getattr(module, "GenerativeModel")
            self.wrap_genai_model_methods(model_class)

        # Patch LangChain integration
        if hasattr(module, "ChatGoogleGenerativeAI"):
            chat_class = getattr(module, "ChatGoogleGenerativeAI")
            # Wrap invoke method to capture messages
            original_invoke = chat_class.invoke

            def patched_invoke(self, messages, *args, **kwargs):
                # Store messages in the instance for later use
                self._last_messages = messages
                return original_invoke(self, messages, *args, **kwargs)

            chat_class.invoke = patched_invoke

            # LangChain v0.2+ uses invoke/ainvoke
            self.wrap_method(chat_class, "_generate")
            if hasattr(chat_class, "_agenerate"):
                self.wrap_method(chat_class, "_agenerate")
            # Fallback for completion methods
            if hasattr(chat_class, "complete"):
                self.wrap_method(chat_class, "complete")
            if hasattr(chat_class, "acomplete"):
                self.wrap_method(chat_class, "acomplete")

    def patch_vertex_ai_methods(self, module):
        # Patch the GenerativeModel class
        if hasattr(module, "generative_models"):
            gen_models = getattr(module, "generative_models")
            if hasattr(gen_models, "GenerativeModel"):
                model_class = getattr(gen_models, "GenerativeModel")
                self.wrap_vertex_model_methods(model_class)

        # Also patch the class directly if available
        if hasattr(module, "GenerativeModel"):
            model_class = getattr(module, "GenerativeModel")
            self.wrap_vertex_model_methods(model_class)

    def wrap_vertex_model_methods(self, model_class):
        # Patch both sync and async methods
        self.wrap_method(model_class, "generate_content")
        if hasattr(model_class, "generate_content_async"):
            self.wrap_method(model_class, "generate_content_async")

    def patch_litellm_methods(self, module):
        self.wrap_method(module, "completion")
        self.wrap_method(module, "acompletion")

    def patch_langchain_google_methods(self, module):
        """Patch LangChain's Google integration methods"""
        if hasattr(module, "ChatVertexAI"):
            chat_class = getattr(module, "ChatVertexAI")
            # LangChain v0.2+ uses invoke/ainvoke
            self.wrap_method(chat_class, "_generate")
            if hasattr(chat_class, "_agenerate"):
                self.wrap_method(chat_class, "_agenerate")
            # Fallback for completion methods
            if hasattr(chat_class, "complete"):
                self.wrap_method(chat_class, "complete")
            if hasattr(chat_class, "acomplete"):
                self.wrap_method(chat_class, "acomplete")

        if hasattr(module, "ChatGoogleGenerativeAI"):
            chat_class = getattr(module, "ChatGoogleGenerativeAI")
            # LangChain v0.2+ uses invoke/ainvoke
            self.wrap_method(chat_class, "_generate")
            if hasattr(chat_class, "_agenerate"):
                self.wrap_method(chat_class, "_agenerate")
            # Fallback for completion methods
            if hasattr(chat_class, "complete"):
                self.wrap_method(chat_class, "complete")
            if hasattr(chat_class, "acomplete"):
                self.wrap_method(chat_class, "acomplete")

    def wrap_openai_client_methods(self, client_class):
        original_init = client_class.__init__

        @functools.wraps(original_init)
        def patched_init(client_self, *args, **kwargs):
            original_init(client_self, *args, **kwargs)
            self.wrap_method(client_self.chat.completions, "create")
            if hasattr(client_self.chat.completions, "acreate"):
                self.wrap_method(client_self.chat.completions, "acreate")

        setattr(client_class, "__init__", patched_init)

    def wrap_anthropic_client_methods(self, client_class):
        original_init = client_class.__init__

        @functools.wraps(original_init)
        def patched_init(client_self, *args, **kwargs):
            original_init(client_self, *args, **kwargs)
            self.wrap_method(client_self.messages, "create")
            if hasattr(client_self.messages, "acreate"):
                self.wrap_method(client_self.messages, "acreate")

        setattr(client_class, "__init__", patched_init)

    def wrap_genai_model_methods(self, model_class):
        original_init = model_class.__init__

        @functools.wraps(original_init)
        def patched_init(model_self, *args, **kwargs):
            original_init(model_self, *args, **kwargs)
            self.wrap_method(model_self, "generate_content")
            if hasattr(model_self, "generate_content_async"):
                self.wrap_method(model_self, "generate_content_async")

        setattr(model_class, "__init__", patched_init)

    def wrap_method(self, obj, method_name):
        """
        Wrap a method with tracing functionality.
        Works for both class methods and instance methods.
        """
        # If obj is a class, we need to patch both the class and any existing instances
        if isinstance(obj, type):
            # Store the original class method
            original_method = getattr(obj, method_name)

            @wrapt.decorator
            def wrapper(wrapped, instance, args, kwargs):
                if asyncio.iscoroutinefunction(wrapped):
                    return self.trace_llm_call(wrapped, *args, **kwargs)
                return self.trace_llm_call_sync(wrapped, *args, **kwargs)

            # Wrap the class method
            wrapped_method = wrapper(original_method)
            setattr(obj, method_name, wrapped_method)
            self.patches.append((obj, method_name, original_method))

        else:
            # For instance methods
            original_method = getattr(obj, method_name)

            @wrapt.decorator
            def wrapper(wrapped, instance, args, kwargs):
                if asyncio.iscoroutinefunction(wrapped):
                    return self.trace_llm_call(wrapped, *args, **kwargs)
                return self.trace_llm_call_sync(wrapped, *args, **kwargs)

            wrapped_method = wrapper(original_method)
            setattr(obj, method_name, wrapped_method)
            self.patches.append((obj, method_name, original_method))

    def _extract_model_name(self, args, kwargs, result):
        """Extract model name from kwargs or result"""
        # First try direct model parameter
        model = kwargs.get("model", "")

        if not model:
            # Try to get from instance
            instance = kwargs.get("self", None)
            if instance:
                # Try model_name first (Google format)
                if hasattr(instance, "model_name"):
                    model = instance.model_name
                # Try model attribute
                elif hasattr(instance, "model"):
                    model = instance.model

        # TODO: This way isn't scalable. The necessity for normalising model names needs to be fixed. We shouldn't have to do this
        # Normalize Google model names
        if model and isinstance(model, str):
            model = model.lower()
            if "gemini-1.5-flash" in model:
                return "gemini-1.5-flash"
            if "gemini-1.5-pro" in model:
                return "gemini-1.5-pro"
            if "gemini-pro" in model:
                return "gemini-pro"

        if 'to_dict' in dir(result):
            result = result.to_dict()
            if 'model_version' in result:
                model = result['model_version']

        return model or "default"

    def _extract_parameters(self, kwargs):
        """Extract all non-null parameters from kwargs"""
        parameters = {k: v for k, v in kwargs.items() if v is not None}

        # Remove contents key in parameters (Google LLM Response)
        if 'contents' in parameters:
            del parameters['contents']

        # Remove messages key in parameters (OpenAI message)
        if 'messages' in parameters:
            del parameters['messages']

        if 'generation_config' in parameters:
            generation_config = parameters['generation_config']
            # If generation_config is already a dict, use it directly
            if isinstance(generation_config, dict):
                config_dict = generation_config
            else:
                # Convert GenerationConfig to dictionary if it has a to_dict method, otherwise try to get its __dict__
                config_dict = getattr(generation_config, 'to_dict', lambda: generation_config.__dict__)()
            parameters.update(config_dict)
            del parameters['generation_config']

        return parameters

    def _extract_token_usage(self, result):
        """Extract token usage from result"""
        # Handle coroutines
        if asyncio.iscoroutine(result):
            # Get the current event loop
            loop = asyncio.get_event_loop()
            # Run the coroutine in the current event loop
            result = loop.run_until_complete(result)

        # Handle standard OpenAI/Anthropic format
        if hasattr(result, "usage"):
            usage = result.usage
            return {
                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                "completion_tokens": getattr(usage, "completion_tokens", 0),
                "total_tokens": getattr(usage, "total_tokens", 0)
            }

        # Handle Google GenerativeAI format with usage_metadata
        if hasattr(result, "usage_metadata"):
            metadata = result.usage_metadata
            return {
                "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
                "completion_tokens": getattr(metadata, "candidates_token_count", 0),
                "total_tokens": getattr(metadata, "total_token_count", 0)
            }

        # Handle Vertex AI format
        if hasattr(result, "text"):
            # For LangChain ChatVertexAI
            total_tokens = getattr(result, "token_count", 0)
            if not total_tokens and hasattr(result, "_raw_response"):
                # Try to get from raw response
                total_tokens = getattr(result._raw_response, "token_count", 0)
            return {
                # TODO: This implementation is incorrect. Vertex AI does provide this breakdown
                "prompt_tokens": 0,  # Vertex AI doesn't provide this breakdown
                "completion_tokens": total_tokens,
                "total_tokens": total_tokens
            }

        return {  # TODO: Passing 0 in case of not recorded is not correct. This needs to be fixes. Discuss before making changes to this
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0
        }

    def _extract_input_data(self, args, kwargs, result):
        return {
            'args': args,
            'kwargs': kwargs
        }

    def _calculate_cost(self, token_usage, model_name):
        # TODO: Passing default cost is a faulty logic & implementation and should be fixed
        """Calculate cost based on token usage and model"""
        if not isinstance(token_usage, dict):
            token_usage = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": token_usage if isinstance(token_usage, (int, float)) else 0
            }

        # TODO: This is a temporary fix. This needs to be fixed

        # Get model costs, defaulting to Vertex AI PaLM2 costs if unknown
        model_cost = self.model_costs.get(model_name, {
            "input_cost_per_token": 0.0,
            "output_cost_per_token": 0.0
        })

        input_cost = (token_usage.get("prompt_tokens", 0)) * model_cost.get("input_cost_per_token", 0.0)
        output_cost = (token_usage.get("completion_tokens", 0)) * model_cost.get("output_cost_per_token", 0.0)
        total_cost = input_cost + output_cost

        # TODO: Return the value as it is, no need to round
        return {
            "input_cost": round(input_cost, 10),
            "output_cost": round(output_cost, 10),
            "total_cost": round(total_cost, 10)
        }

    def create_llm_component(self, component_id, hash_id, name, llm_type, version, memory_used, start_time, end_time, input_data, output_data, cost={}, usage={}, error=None, parameters={}):
        # Update total metrics
        self.total_tokens += usage.get("total_tokens", 0)
        self.total_cost += cost.get("total_cost", 0)

        component = {
            "id": component_id,
            "hash_id": hash_id,
            "source_hash_id": None,
            "type": "llm",
            "name": name,
            "start_time": start_time.isoformat(),
            "end_time": end_time.isoformat(),
            "error": error,
            "parent_id": self.current_agent_id.get(),
            "info": {
                "model": llm_type,
                "version": version,
                "memory_used": memory_used,
                "cost": cost,
                "tokens": usage,
                **parameters
            },
            "data": {
                "input": input_data['args'] if hasattr(input_data, 'args') else input_data,
                "output": output_data.output_response if output_data else None,
                "memory_used": memory_used
            },
            "network_calls": self.component_network_calls.get(component_id, []),
            "interactions": self.component_user_interaction.get(component_id, [])
        }

        if self.gt:
            component["data"]["gt"] = self.gt

        return component

    def start_component(self, component_id):
        """Start tracking network calls for a component"""
        self.component_network_calls[component_id] = []
        self.current_component_id = component_id

    def end_component(self, component_id):
        """Stop tracking network calls for a component"""
        self.current_component_id = None

    async def trace_llm_call(self, original_func, *args, **kwargs):
        """Trace an LLM API call"""
        if not self.is_active:
            return await original_func(*args, **kwargs)

        start_time = datetime.now().astimezone()
        start_memory = psutil.Process().memory_info().rss
        component_id = str(uuid.uuid4())
        hash_id = generate_unique_hash_simple(original_func)

        # Start tracking network calls for this component
        self.start_component(component_id)

        try:
            # Execute the LLM call
            result = await original_func(*args, **kwargs)

            # Calculate resource usage
            end_time = datetime.now().astimezone()
            end_memory = psutil.Process().memory_info().rss
            memory_used = max(0, end_memory - start_memory)

            # Extract token usage and calculate cost
            token_usage = self._extract_token_usage(result)
            model_name = self._extract_model_name(args, kwargs, result)
            cost = self._calculate_cost(token_usage, model_name)
            parameters = self._extract_parameters(kwargs)

            # End tracking network calls for this component
            self.end_component(component_id)

            name = self.current_llm_call_name.get()
            if name is None:
                name = original_func.__name__

            # Create input data with ground truth
            input_data = self._extract_input_data(args, kwargs, result)

            # Create LLM component
            llm_component = self.create_llm_component(
                component_id=component_id,
                hash_id=hash_id,
                name=name,
                llm_type=model_name,
                version="1.0.0",
                memory_used=memory_used,
                start_time=start_time,
                end_time=end_time,
                input_data=input_data,
                output_data=extract_llm_output(result),
                cost=cost,
                usage=token_usage,
                parameters=parameters
            )

            # self.add_component(llm_component)
            self.llm_data = llm_component
            return result

        except Exception as e:
            error_component = {
                "code": 500,
                "type": type(e).__name__,
                "message": str(e),
                "details": {}
            }

            # End tracking network calls for this component
            self.end_component(component_id)

            end_time = datetime.now().astimezone()

            name = self.current_llm_call_name.get()
            if name is None:
                name = original_func.__name__

            llm_component = self.create_llm_component(
                component_id=component_id,
                hash_id=hash_id,
                name=name,
                llm_type="unknown",
                version="1.0.0",
                memory_used=0,
                start_time=start_time,
                end_time=end_time,
                input_data=self._extract_input_data(args, kwargs, None),
                output_data=None,
                error=error_component
            )

            self.add_component(llm_component)
            raise

    def trace_llm_call_sync(self, original_func, *args, **kwargs):
        """Sync version of trace_llm_call"""
        if not self.is_active:
            if asyncio.iscoroutinefunction(original_func):
                return asyncio.run(original_func(*args, **kwargs))
            return original_func(*args, **kwargs)

        start_time = datetime.now().astimezone()
        component_id = str(uuid.uuid4())
        hash_id = generate_unique_hash_simple(original_func)

        # Start tracking network calls for this component
        self.start_component(component_id)

        # Calculate resource usage
        end_time = datetime.now().astimezone()
        start_memory = psutil.Process().memory_info().rss

        try:
            # Execute the function
            if asyncio.iscoroutinefunction(original_func):
                result = asyncio.run(original_func(*args, **kwargs))
            else:
                result = original_func(*args, **kwargs)

            end_memory = psutil.Process().memory_info().rss
            memory_used = max(0, end_memory - start_memory)

            # Extract token usage and calculate cost
            token_usage = self._extract_token_usage(result)
            model_name = self._extract_model_name(args, kwargs, result)
            cost = self._calculate_cost(token_usage, model_name)
            parameters = self._extract_parameters(kwargs)

            # End tracking network calls for this component
            self.end_component(component_id)

            name = self.current_llm_call_name.get()
            if name is None:
                name = original_func.__name__

            # Create input data with ground truth
            input_data = self._extract_input_data(args, kwargs, result)

            # Create LLM component
            llm_component = self.create_llm_component(
                component_id=component_id,
                hash_id=hash_id,
                name=name,
                llm_type=model_name,
                version="1.0.0",
                memory_used=memory_used,
                start_time=start_time,
                end_time=end_time,
                input_data=input_data,
                output_data=extract_llm_output(result),
                cost=cost,
                usage=token_usage,
                parameters=parameters
            )

            self.add_component(llm_component)
            return result

        except Exception as e:
            error_component = {
                "code": 500,
                "type": type(e).__name__,
                "message": str(e),
                "details": {}
            }

            # End tracking network calls for this component
            self.end_component(component_id)

            end_time = datetime.now().astimezone()

            name = self.current_llm_call_name.get()
            if name is None:
                name = original_func.__name__

            end_memory = psutil.Process().memory_info().rss
            memory_used = max(0, end_memory - start_memory)

            llm_component = self.create_llm_component(
                component_id=component_id,
                hash_id=hash_id,
                name=name,
                llm_type="unknown",
                version="1.0.0",
                memory_used=memory_used,
                start_time=start_time,
                end_time=end_time,
                input_data=self._extract_input_data(args, kwargs, None),
                output_data=None,
                error=error_component
            )

            self.add_component(llm_component)
            raise

    def trace_llm(self, name: str = None):
        def decorator(func):
            @self.file_tracker.trace_decorator
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                self.gt = kwargs.get('gt', None) if kwargs else None
                if not self.is_active:
                    return await func(*args, **kwargs)

                hash_id = generate_unique_hash_simple(func)
                component_id = str(uuid.uuid4())
                parent_agent_id = self.current_agent_id.get()
                self.start_component(component_id)

                start_time = datetime.now()
                error_info = None
                result = None

                try:
                    result = await func(*args, **kwargs)
                    return result
                except Exception as e:
                    error_info = {
                        "error": {
                            "type": type(e).__name__,
                            "message": str(e),
                            "traceback": traceback.format_exc(),
                            "timestamp": datetime.now().isoformat()
                        }
                    }
                    raise
                finally:
                    llm_component = self.llm_data

                    if error_info:
                        llm_component["error"] = error_info["error"]

                    if parent_agent_id:
                        children = self.agent_children.get()
                        children.append(llm_component)
                        self.agent_children.set(children)
                    else:
                        self.add_component(llm_component)

                    self.end_component(component_id)

            @self.file_tracker.trace_decorator
            @functools.wraps(func)
            def sync_wrapper(*args, **kwargs):
                self.gt = kwargs.get('gt', None) if kwargs else None
                if not self.is_active:
                    return func(*args, **kwargs)

                hash_id = generate_unique_hash_simple(func)

                component_id = str(uuid.uuid4())
                parent_agent_id = self.current_agent_id.get()
                self.start_component(component_id)

                start_time = datetime.now()
                error_info = None
                result = None

                try:
                    result = func(*args, **kwargs)
                    return result
                except Exception as e:
                    error_info = {
                        "error": {
                            "type": type(e).__name__,
                            "message": str(e),
                            "traceback": traceback.format_exc(),
                            "timestamp": datetime.now().isoformat()
                        }
                    }
                    raise
                finally:
                    llm_component = self.llm_data

                    if error_info:
                        llm_component["error"] = error_info["error"]

                    if parent_agent_id:
                        children = self.agent_children.get()
                        children.append(llm_component)
                        self.agent_children.set(children)
                    else:
                        self.add_component(llm_component)

                    self.end_component(component_id)

            return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
        return decorator

    def unpatch_llm_calls(self):
        # Remove all patches
        for obj, method_name, original_method in self.patches:
            try:
                setattr(obj, method_name, original_method)
            except Exception as e:
                print(f"Error unpatching {method_name}: {str(e)}")
        self.patches = []

    def _sanitize_api_keys(self, data):
        """Remove sensitive information from data"""
        if isinstance(data, dict):
            return {k: self._sanitize_api_keys(v) for k, v in data.items()
                    if not any(sensitive in k.lower() for sensitive in ['key', 'token', 'secret', 'password'])}
        elif isinstance(data, list):
            return [self._sanitize_api_keys(item) for item in data]
        elif isinstance(data, tuple):
            return tuple(self._sanitize_api_keys(item) for item in data)
        return data

    def _sanitize_input(self, args, kwargs):
        """Convert input arguments to text format.

        Args:
            args: Input arguments that may contain nested dictionaries

        Returns:
            str: Text representation of the input arguments
        """
        if isinstance(args, dict):
            return str({k: self._sanitize_input(v, {}) for k, v in args.items()})
        elif isinstance(args, (list, tuple)):
            return str([self._sanitize_input(item, {}) for item in args])
        return str(args)


def extract_llm_output(result):
    """Extract output from LLM response"""
    class OutputResponse:
        def __init__(self, output_response):
            self.output_response = output_response

    # Handle coroutines
    if asyncio.iscoroutine(result):
        # For sync context, run the coroutine
        if not asyncio.get_event_loop().is_running():
            result = asyncio.run(result)
        else:
            # We're in an async context, but this function is called synchronously
            # Return a placeholder and let the caller handle the coroutine
            return OutputResponse("Coroutine result pending")

    # Handle Google GenerativeAI format
    if hasattr(result, "result"):
        candidates = getattr(result.result, "candidates", [])
        output = []
        for candidate in candidates:
            content = getattr(candidate, "content", None)
            if content and hasattr(content, "parts"):
                for part in content.parts:
                    if hasattr(part, "text"):
                        output.append({
                            "content": part.text,
                            "role": getattr(content, "role", "assistant"),
                            "finish_reason": getattr(candidate, "finish_reason", None)
                        })
        return OutputResponse(output)

    # Handle Vertex AI format
    if hasattr(result, "text"):
        return OutputResponse([{
            "content": result.text,
            "role": "assistant"
        }])

    # Handle OpenAI format
    if hasattr(result, "choices"):
        return OutputResponse([{
            "content": choice.message.content,
            "role": choice.message.role
        } for choice in result.choices])

    # Handle Anthropic format
    if hasattr(result, "completion"):
        return OutputResponse([{
            "content": result.completion,
            "role": "assistant"
        }])

    # Default case
    return OutputResponse(str(result))