ragaai-catalyst 2.0.7.2__py3-none-any.whl → 2.0.7.2b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragaai_catalyst/evaluation.py +107 -153
- ragaai_catalyst/tracers/agentic_tracing/Untitled-1.json +660 -0
- ragaai_catalyst/tracers/agentic_tracing/__init__.py +3 -0
- ragaai_catalyst/tracers/agentic_tracing/agent_tracer.py +311 -0
- ragaai_catalyst/tracers/agentic_tracing/agentic_tracing.py +212 -0
- ragaai_catalyst/tracers/agentic_tracing/base.py +270 -0
- ragaai_catalyst/tracers/agentic_tracing/data_structure.py +239 -0
- ragaai_catalyst/tracers/agentic_tracing/llm_tracer.py +906 -0
- ragaai_catalyst/tracers/agentic_tracing/network_tracer.py +286 -0
- ragaai_catalyst/tracers/agentic_tracing/sample.py +197 -0
- ragaai_catalyst/tracers/agentic_tracing/tool_tracer.py +235 -0
- ragaai_catalyst/tracers/agentic_tracing/unique_decorator.py +221 -0
- ragaai_catalyst/tracers/agentic_tracing/unique_decorator_test.py +172 -0
- ragaai_catalyst/tracers/agentic_tracing/user_interaction_tracer.py +67 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/__init__.py +3 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/api_utils.py +18 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/data_classes.py +61 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/generic.py +32 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +181 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +5946 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py +74 -0
- ragaai_catalyst/tracers/tracer.py +26 -4
- ragaai_catalyst/tracers/upload_traces.py +127 -0
- ragaai_catalyst-2.0.7.2b0.dist-info/METADATA +39 -0
- ragaai_catalyst-2.0.7.2b0.dist-info/RECORD +50 -0
- ragaai_catalyst-2.0.7.2.dist-info/METADATA +0 -386
- ragaai_catalyst-2.0.7.2.dist-info/RECORD +0 -29
- {ragaai_catalyst-2.0.7.2.dist-info → ragaai_catalyst-2.0.7.2b0.dist-info}/WHEEL +0 -0
- {ragaai_catalyst-2.0.7.2.dist-info → ragaai_catalyst-2.0.7.2b0.dist-info}/top_level.txt +0 -0
ragaai_catalyst/tracers/agentic_tracing/llm_tracer.py
@@ -0,0 +1,906 @@
from typing import Optional, Any, Dict, List
import asyncio
import psutil
import json
import wrapt
import functools
from datetime import datetime
import uuid
import os
import contextvars
import sys
import gc

from .unique_decorator import mydecorator
from .utils.trace_utils import calculate_cost, load_model_costs
from .utils.llm_utils import extract_llm_output


class LLMTracerMixin:
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.patches = []
        try:
            self.model_costs = load_model_costs()
        except Exception as e:
            # If model costs can't be loaded, use default costs
            self.model_costs = {
                "default": {
                    "input_cost_per_token": 0.00002,
                    "output_cost_per_token": 0.00002
                }
            }
        self.current_llm_call_name = contextvars.ContextVar("llm_call_name", default=None)
        self.component_network_calls = {}
        self.current_component_id = None
        self.total_tokens = 0
        self.total_cost = 0.0
        # Apply decorator to trace_llm_call method
        self.trace_llm_call = mydecorator(self.trace_llm_call)

    def instrument_llm_calls(self):
        # Handle modules that are already imported
        import sys

        if "vertexai" in sys.modules:
            self.patch_vertex_ai_methods(sys.modules["vertexai"])
        if "vertexai.generative_models" in sys.modules:
            self.patch_vertex_ai_methods(sys.modules["vertexai.generative_models"])

        if "openai" in sys.modules:
            self.patch_openai_methods(sys.modules["openai"])
        if "litellm" in sys.modules:
            self.patch_litellm_methods(sys.modules["litellm"])
        if "anthropic" in sys.modules:
            self.patch_anthropic_methods(sys.modules["anthropic"])
        if "google.generativeai" in sys.modules:
            self.patch_google_genai_methods(sys.modules["google.generativeai"])
        if "langchain_google_vertexai" in sys.modules:
            self.patch_langchain_google_methods(sys.modules["langchain_google_vertexai"])
        if "langchain_google_genai" in sys.modules:
            self.patch_langchain_google_methods(sys.modules["langchain_google_genai"])

        # Register hooks for future imports
        wrapt.register_post_import_hook(self.patch_vertex_ai_methods, "vertexai")
        wrapt.register_post_import_hook(self.patch_vertex_ai_methods, "vertexai.generative_models")
        wrapt.register_post_import_hook(self.patch_openai_methods, "openai")
        wrapt.register_post_import_hook(self.patch_litellm_methods, "litellm")
        wrapt.register_post_import_hook(self.patch_anthropic_methods, "anthropic")
        wrapt.register_post_import_hook(self.patch_google_genai_methods, "google.generativeai")

        # Add hooks for LangChain integrations
        wrapt.register_post_import_hook(self.patch_langchain_google_methods, "langchain_google_vertexai")
        wrapt.register_post_import_hook(self.patch_langchain_google_methods, "langchain_google_genai")

    def patch_openai_methods(self, module):
        try:
            if hasattr(module, "OpenAI"):
                client_class = getattr(module, "OpenAI")
                self.wrap_openai_client_methods(client_class)
            if hasattr(module, "AsyncOpenAI"):
                async_client_class = getattr(module, "AsyncOpenAI")
                self.wrap_openai_client_methods(async_client_class)
        except Exception as e:
            # Log the error but continue execution
            print(f"Warning: Failed to patch OpenAI methods: {str(e)}")

    def patch_anthropic_methods(self, module):
        if hasattr(module, "Anthropic"):
            client_class = getattr(module, "Anthropic")
            self.wrap_anthropic_client_methods(client_class)

    def patch_google_genai_methods(self, module):
        # Patch direct Google GenerativeAI usage
        if hasattr(module, "GenerativeModel"):
            model_class = getattr(module, "GenerativeModel")
            self.wrap_genai_model_methods(model_class)

        # Patch LangChain integration
        if hasattr(module, "ChatGoogleGenerativeAI"):
            chat_class = getattr(module, "ChatGoogleGenerativeAI")
            # Wrap invoke method to capture messages
            original_invoke = chat_class.invoke

            def patched_invoke(self, messages, *args, **kwargs):
                # Store messages in the instance for later use
                self._last_messages = messages
                return original_invoke(self, messages, *args, **kwargs)

            chat_class.invoke = patched_invoke

            # LangChain v0.2+ uses invoke/ainvoke
            self.wrap_method(chat_class, "_generate")
            if hasattr(chat_class, "_agenerate"):
                self.wrap_method(chat_class, "_agenerate")
            # Fallback for completion methods
            if hasattr(chat_class, "complete"):
                self.wrap_method(chat_class, "complete")
            if hasattr(chat_class, "acomplete"):
                self.wrap_method(chat_class, "acomplete")

    def patch_vertex_ai_methods(self, module):
        # Patch the GenerativeModel class
        if hasattr(module, "generative_models"):
            gen_models = getattr(module, "generative_models")
            if hasattr(gen_models, "GenerativeModel"):
                model_class = getattr(gen_models, "GenerativeModel")
                self.wrap_vertex_model_methods(model_class)

        # Also patch the class directly if available
        if hasattr(module, "GenerativeModel"):
            model_class = getattr(module, "GenerativeModel")
            self.wrap_vertex_model_methods(model_class)

    def wrap_vertex_model_methods(self, model_class):
        # Patch both sync and async methods
        self.wrap_method(model_class, "generate_content")
        if hasattr(model_class, "generate_content_async"):
            self.wrap_method(model_class, "generate_content_async")

    def patch_litellm_methods(self, module):
        self.wrap_method(module, "completion")
        self.wrap_method(module, "acompletion")

    def patch_langchain_google_methods(self, module):
        """Patch LangChain's Google integration methods"""
        if hasattr(module, "ChatVertexAI"):
            chat_class = getattr(module, "ChatVertexAI")
            # LangChain v0.2+ uses invoke/ainvoke
            self.wrap_method(chat_class, "_generate")
            if hasattr(chat_class, "_agenerate"):
                self.wrap_method(chat_class, "_agenerate")
            # Fallback for completion methods
            if hasattr(chat_class, "complete"):
                self.wrap_method(chat_class, "complete")
            if hasattr(chat_class, "acomplete"):
                self.wrap_method(chat_class, "acomplete")

        if hasattr(module, "ChatGoogleGenerativeAI"):
            chat_class = getattr(module, "ChatGoogleGenerativeAI")
            # LangChain v0.2+ uses invoke/ainvoke
            self.wrap_method(chat_class, "_generate")
            if hasattr(chat_class, "_agenerate"):
                self.wrap_method(chat_class, "_agenerate")
            # Fallback for completion methods
            if hasattr(chat_class, "complete"):
                self.wrap_method(chat_class, "complete")
            if hasattr(chat_class, "acomplete"):
                self.wrap_method(chat_class, "acomplete")

    def wrap_openai_client_methods(self, client_class):
        original_init = client_class.__init__

        @functools.wraps(original_init)
        def patched_init(client_self, *args, **kwargs):
            original_init(client_self, *args, **kwargs)
            self.wrap_method(client_self.chat.completions, "create")
            if hasattr(client_self.chat.completions, "acreate"):
                self.wrap_method(client_self.chat.completions, "acreate")

        setattr(client_class, "__init__", patched_init)

    def wrap_anthropic_client_methods(self, client_class):
        original_init = client_class.__init__

        @functools.wraps(original_init)
        def patched_init(client_self, *args, **kwargs):
            original_init(client_self, *args, **kwargs)
            self.wrap_method(client_self.messages, "create")
            if hasattr(client_self.messages, "acreate"):
                self.wrap_method(client_self.messages, "acreate")

        setattr(client_class, "__init__", patched_init)

    def wrap_genai_model_methods(self, model_class):
        original_init = model_class.__init__

        @functools.wraps(original_init)
        def patched_init(model_self, *args, **kwargs):
            original_init(model_self, *args, **kwargs)
            self.wrap_method(model_self, "generate_content")
            if hasattr(model_self, "generate_content_async"):
                self.wrap_method(model_self, "generate_content_async")

        setattr(model_class, "__init__", patched_init)

    def wrap_method(self, obj, method_name):
        """
        Wrap a method with tracing functionality.
        Works for both class methods and instance methods.
        """
        # If obj is a class, we need to patch both the class and any existing instances
        if isinstance(obj, type):
            # Store the original class method
            original_method = getattr(obj, method_name)

            @wrapt.decorator
            def wrapper(wrapped, instance, args, kwargs):
                if asyncio.iscoroutinefunction(wrapped):
                    return self.trace_llm_call(wrapped, *args, **kwargs)
                return self.trace_llm_call_sync(wrapped, *args, **kwargs)

            # Wrap the class method
            wrapped_method = wrapper(original_method)
            setattr(obj, method_name, wrapped_method)
            self.patches.append((obj, method_name, original_method))

        else:
            # For instance methods
            original_method = getattr(obj, method_name)

            @wrapt.decorator
            def wrapper(wrapped, instance, args, kwargs):
                if asyncio.iscoroutinefunction(wrapped):
                    return self.trace_llm_call(wrapped, *args, **kwargs)
                return self.trace_llm_call_sync(wrapped, *args, **kwargs)

            wrapped_method = wrapper(original_method)
            setattr(obj, method_name, wrapped_method)
            self.patches.append((obj, method_name, original_method))

    def _extract_model_name(self, kwargs):
        """Extract model name from kwargs or result"""
        # First try direct model parameter
        model = kwargs.get("model", "")

        if not model:
            # Try to get from instance
            instance = kwargs.get("self", None)
            if instance:
                # Try model_name first (Google format)
                if hasattr(instance, "model_name"):
                    model = instance.model_name
                # Try model attribute
                elif hasattr(instance, "model"):
                    model = instance.model

        # Normalize Google model names
        if model and isinstance(model, str):
            model = model.lower()
            if "gemini-1.5-flash" in model:
                return "gemini-1.5-flash"
            if "gemini-1.5-pro" in model:
                return "gemini-1.5-pro"
            if "gemini-pro" in model:
                return "gemini-pro"

        return model or "default"

    def _extract_parameters(self, kwargs, result=None):
        """Extract parameters from kwargs or result"""
        params = {
            "temperature": kwargs.get("temperature", getattr(result, "temperature", 0.7)),
            "top_p": kwargs.get("top_p", getattr(result, "top_p", 1.0)),
            "max_tokens": kwargs.get("max_tokens", getattr(result, "max_tokens", 512))
        }

        # Add Google AI specific parameters if available
        if hasattr(kwargs.get("self", None), "generation_config"):
            gen_config = kwargs["self"].generation_config
            params.update({
                "candidate_count": getattr(gen_config, "candidate_count", 1),
                "stop_sequences": getattr(gen_config, "stop_sequences", []),
                "top_k": getattr(gen_config, "top_k", 40)
            })

        return params

    def _extract_token_usage(self, result):
        """Extract token usage from result"""
        # Handle coroutines
        if asyncio.iscoroutine(result):
            result = asyncio.run(result)

        # Handle standard OpenAI/Anthropic format
        if hasattr(result, "usage"):
            usage = result.usage
            return {
                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                "completion_tokens": getattr(usage, "completion_tokens", 0),
                "total_tokens": getattr(usage, "total_tokens", 0)
            }

        # Handle Google GenerativeAI format with usage_metadata
        if hasattr(result, "usage_metadata"):
            metadata = result.usage_metadata
            return {
                "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
                "completion_tokens": getattr(metadata, "candidates_token_count", 0),
                "total_tokens": getattr(metadata, "total_token_count", 0)
            }

        # Handle Vertex AI format
        if hasattr(result, "text"):
            # For LangChain ChatVertexAI
            total_tokens = getattr(result, "token_count", 0)
            if not total_tokens and hasattr(result, "_raw_response"):
                # Try to get from raw response
                total_tokens = getattr(result._raw_response, "token_count", 0)
            return {
                "prompt_tokens": 0,  # Vertex AI doesn't provide this breakdown
                "completion_tokens": total_tokens,
                "total_tokens": total_tokens
            }

        return {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0
        }

    def _extract_input_data(self, kwargs, result):
        """Extract input data from kwargs and result"""

        # For Vertex AI GenerationResponse
        if hasattr(result, 'candidates') and hasattr(result, 'usage_metadata'):
            # Extract generation config
            generation_config = kwargs.get('generation_config', {})
            config_dict = {}
            if hasattr(generation_config, 'temperature'):
                config_dict['temperature'] = generation_config.temperature
            if hasattr(generation_config, 'top_p'):
                config_dict['top_p'] = generation_config.top_p
            if hasattr(generation_config, 'max_output_tokens'):
                config_dict['max_tokens'] = generation_config.max_output_tokens
            if hasattr(generation_config, 'candidate_count'):
                config_dict['n'] = generation_config.candidate_count

            return {
                "prompt": kwargs.get('contents', ''),
                "model": "gemini-1.5-flash-002",
                **config_dict
            }

        # For standard OpenAI format
        messages = kwargs.get("messages", [])
        if messages:
            return {
                "messages": messages,
                "model": kwargs.get("model", "unknown"),
                "temperature": kwargs.get("temperature", 0.7),
                "max_tokens": kwargs.get("max_tokens", None),
                "top_p": kwargs.get("top_p", None),
                "frequency_penalty": kwargs.get("frequency_penalty", None),
                "presence_penalty": kwargs.get("presence_penalty", None)
            }

        # For text completion format
        if "prompt" in kwargs:
            return {
                "prompt": kwargs["prompt"],
                "model": kwargs.get("model", "unknown"),
                "temperature": kwargs.get("temperature", 0.7),
                "max_tokens": kwargs.get("max_tokens", None),
                "top_p": kwargs.get("top_p", None),
                "frequency_penalty": kwargs.get("frequency_penalty", None),
                "presence_penalty": kwargs.get("presence_penalty", None)
            }

        # For any other case, try to extract from kwargs
        if "contents" in kwargs:
            return {
                "prompt": kwargs["contents"],
                "model": kwargs.get("model", "unknown"),
                "temperature": kwargs.get("temperature", 0.7),
                "max_tokens": kwargs.get("max_tokens", None),
                "top_p": kwargs.get("top_p", None)
            }

        print("No input data found")
        return {}

    def _calculate_cost(self, token_usage, model_name):
        """Calculate cost based on token usage and model"""
        if not isinstance(token_usage, dict):
            token_usage = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": token_usage if isinstance(token_usage, (int, float)) else 0
            }

        # Get model costs, defaulting to Vertex AI PaLM2 costs if unknown
        model_cost = self.model_costs.get(model_name, {
            "input_cost_per_token": 0.0005,   # $0.0005 per 1K input tokens
            "output_cost_per_token": 0.0005   # $0.0005 per 1K output tokens
        })

        # Calculate costs per 1K tokens
        input_cost = (token_usage.get("prompt_tokens", 0) / 1000.0) * model_cost.get("input_cost_per_token", 0.0005)
        output_cost = (token_usage.get("completion_tokens", 0) / 1000.0) * model_cost.get("output_cost_per_token", 0.0005)
        total_cost = input_cost + output_cost

        return {
            "input_cost": round(input_cost, 6),
            "output_cost": round(output_cost, 6),
            "total_cost": round(total_cost, 6)
        }

    def create_llm_component(self, **kwargs):
        """Create an LLM component according to the data structure"""
        start_time = kwargs["start_time"]

        # Ensure cost and usage are dictionaries
        cost = kwargs.get("cost", {})
        if not isinstance(cost, dict):
            cost = {"total_cost": cost}

        usage = kwargs.get("usage", {})
        if not isinstance(usage, dict):
            usage = {"total_tokens": usage}

        component = {
            "id": kwargs["component_id"],
            "hash_id": kwargs["hash_id"],
            "source_hash_id": None,
            "type": "llm",
            "name": kwargs["name"],
            "start_time": start_time.isoformat(),
            "end_time": kwargs["end_time"].isoformat(),
            "error": kwargs.get("error"),
            "parent_id": self.current_agent_id.get(),
            "info": {
                "llm_type": kwargs.get("llm_type", "unknown"),
                "version": kwargs.get("version", "1.0.0"),
                "memory_used": kwargs.get("memory_used", 0),
                "cost": cost,
                "tokens": usage
            },
            "data": {
                "input": kwargs.get("input_data"),
                "output": kwargs.get("output_data"),
                "memory_used": kwargs.get("memory_used", 0)
            },
            "network_calls": self.component_network_calls.get(kwargs["component_id"], []),
            "interactions": [
                {
                    "id": f"int_{uuid.uuid4()}",
                    "interaction_type": "input",
                    "timestamp": start_time.isoformat(),
                    "content": kwargs.get("input_data")
                },
                {
                    "id": f"int_{uuid.uuid4()}",
                    "interaction_type": "output",
                    "timestamp": kwargs["end_time"].isoformat(),
                    "content": kwargs.get("output_data")
                }
            ]
        }
        return component

    def start_component(self, component_id):
        """Start tracking network calls for a component"""
        self.component_network_calls[component_id] = []
        self.current_component_id = component_id

    def end_component(self, component_id):
        """Stop tracking network calls for a component"""
        self.current_component_id = None

    async def trace_llm_call(self, original_func, *args, **kwargs):
        """Trace an LLM API call"""
        if not self.is_active:
            if asyncio.iscoroutinefunction(original_func):
                return await original_func(*args, **kwargs)
            return original_func(*args, **kwargs)

        start_time = datetime.now().astimezone()
        start_memory = psutil.Process().memory_info().rss
        component_id = str(uuid.uuid4())
        hash_id = self.trace_llm_call.hash_id

        # Start tracking network calls for this component
        self.start_component(component_id)

        try:
            # Execute the LLM call
            result = None
            if asyncio.iscoroutinefunction(original_func):
                result = await original_func(*args, **kwargs)
            else:
                result = original_func(*args, **kwargs)

            # If result is a coroutine, await it
            if asyncio.iscoroutine(result):
                result = await result

            # Calculate resource usage
            end_time = datetime.now().astimezone()
            end_memory = psutil.Process().memory_info().rss
            memory_used = max(0, end_memory - start_memory)

            # Extract token usage and calculate cost
            token_usage = await self._extract_token_usage(result)
            model_name = self._extract_model_name(kwargs)
            cost = self._calculate_cost(token_usage, model_name)

            # End tracking network calls for this component
            self.end_component(component_id)

            # Create LLM component
            llm_component = self.create_llm_component(
                component_id=component_id,
                hash_id=hash_id,
                name=self.current_llm_call_name.get(),
                llm_type=model_name,
                version="1.0.0",
                memory_used=memory_used,
                start_time=start_time,
                end_time=end_time,
                input_data=self._extract_input_data(kwargs, result),
                output_data=extract_llm_output(result),
                cost=cost,
                usage=token_usage
            )

            self.add_component(llm_component)
            return result

        except Exception as e:
            error_component = {
                "code": 500,
                "type": type(e).__name__,
                "message": str(e),
                "details": {}
            }

            # End tracking network calls for this component
            self.end_component(component_id)

            end_time = datetime.now().astimezone()

            llm_component = self.create_llm_component(
                component_id=component_id,
                hash_id=hash_id,
                name=self.current_llm_call_name.get(),
                llm_type="unknown",
                version="1.0.0",
                memory_used=0,
                start_time=start_time,
                end_time=end_time,
                input_data=self._extract_input_data(kwargs, None),
                output_data=None,
                error=error_component
            )

            self.add_component(llm_component)
            raise

    def _extract_token_usage_sync(self, result):
        """Sync version of extract token usage"""
        # Handle coroutines
        if asyncio.iscoroutine(result):
            result = asyncio.run(result)

        # Handle standard OpenAI/Anthropic format
        if hasattr(result, "usage"):
            usage = result.usage
            return {
                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                "completion_tokens": getattr(usage, "completion_tokens", 0),
                "total_tokens": getattr(usage, "total_tokens", 0)
            }

        # Handle Google GenerativeAI format with usage_metadata
        if hasattr(result, "usage_metadata"):
            metadata = result.usage_metadata
            return {
                "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
                "completion_tokens": getattr(metadata, "candidates_token_count", 0),
                "total_tokens": getattr(metadata, "total_token_count", 0)
            }

        # Handle Vertex AI format
        if hasattr(result, "text"):
            # For LangChain ChatVertexAI
            total_tokens = getattr(result, "token_count", 0)
            if not total_tokens and hasattr(result, "_raw_response"):
                # Try to get from raw response
                total_tokens = getattr(result._raw_response, "token_count", 0)
            return {
                "prompt_tokens": 0,  # Vertex AI doesn't provide this breakdown
                "completion_tokens": total_tokens,
                "total_tokens": total_tokens
            }

        return {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0
        }

    def trace_llm_call_sync(self, original_func, *args, **kwargs):
        """Sync version of trace_llm_call"""
        if not self.is_active:
            if asyncio.iscoroutinefunction(original_func):
                return asyncio.run(original_func(*args, **kwargs))
            return original_func(*args, **kwargs)

        start_time = datetime.now().astimezone()
        start_memory = psutil.Process().memory_info().rss
        component_id = str(uuid.uuid4())
        hash_id = self.trace_llm_call.hash_id

        # Start tracking network calls for this component
        self.start_component(component_id)

        try:
            # Execute the LLM call
            result = None
            if asyncio.iscoroutinefunction(original_func):
                result = asyncio.run(original_func(*args, **kwargs))
            else:
                result = original_func(*args, **kwargs)

            # If result is a coroutine, run it
            if asyncio.iscoroutine(result):
                result = asyncio.run(result)

            # Calculate resource usage
            end_time = datetime.now().astimezone()
            end_memory = psutil.Process().memory_info().rss
            memory_used = max(0, end_memory - start_memory)

            # Extract token usage and calculate cost
            token_usage = self._extract_token_usage_sync(result)
            model_name = self._extract_model_name(kwargs)
            cost = self._calculate_cost(token_usage, model_name)

            # End tracking network calls for this component
            self.end_component(component_id)

            # Create LLM component
            llm_component = self.create_llm_component(
                component_id=component_id,
                hash_id=hash_id,
                name=self.current_llm_call_name.get(),
                llm_type=model_name,
                version="1.0.0",
                memory_used=memory_used,
                start_time=start_time,
                end_time=end_time,
                input_data=self._extract_input_data(kwargs, result),
                output_data=extract_llm_output(result),
                cost=cost,
                usage=token_usage
            )

            self.add_component(llm_component)
            return result

        except Exception as e:
            error_component = {
                "code": 500,
                "type": type(e).__name__,
                "message": str(e),
                "details": {}
            }

            # End tracking network calls for this component
            self.end_component(component_id)

            end_time = datetime.now().astimezone()

            llm_component = self.create_llm_component(
                component_id=component_id,
                hash_id=hash_id,
                name=self.current_llm_call_name.get(),
                llm_type="unknown",
                version="1.0.0",
                memory_used=0,
                start_time=start_time,
                end_time=end_time,
                input_data=self._extract_input_data(kwargs, None),
                output_data=None,
                error=error_component
            )

            self.add_component(llm_component)
            raise

    async def _extract_token_usage(self, result):
        """Extract token usage from result"""
        # Handle coroutines
        if asyncio.iscoroutine(result):
            result = await result

        # Handle standard OpenAI/Anthropic format
        if hasattr(result, "usage"):
            usage = result.usage
            return {
                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                "completion_tokens": getattr(usage, "completion_tokens", 0),
                "total_tokens": getattr(usage, "total_tokens", 0)
            }

        # Handle Google GenerativeAI format with usage_metadata
        if hasattr(result, "usage_metadata"):
            metadata = result.usage_metadata
            return {
                "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
                "completion_tokens": getattr(metadata, "candidates_token_count", 0),
                "total_tokens": getattr(metadata, "total_token_count", 0)
            }

        # Handle Vertex AI format
        if hasattr(result, "text"):
            # For LangChain ChatVertexAI
            total_tokens = getattr(result, "token_count", 0)
            if not total_tokens and hasattr(result, "_raw_response"):
                # Try to get from raw response
                total_tokens = getattr(result._raw_response, "token_count", 0)
            return {
                "prompt_tokens": 0,  # Vertex AI doesn't provide this breakdown
                "completion_tokens": total_tokens,
                "total_tokens": total_tokens
            }

        return {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0
        }

    def trace_llm(self, name: str, tool_type: str = "llm", version: str = "1.0.0"):
        def decorator(func_or_class):
            if isinstance(func_or_class, type):
                for attr_name, attr_value in func_or_class.__dict__.items():
                    if callable(attr_value) and not attr_name.startswith("__"):
                        setattr(
                            func_or_class,
                            attr_name,
                            self.trace_llm(f"{name}.{attr_name}", tool_type, version)(attr_value),
                        )
                return func_or_class
            else:
                @functools.wraps(func_or_class)
                async def async_wrapper(*args, **kwargs):
                    token = self.current_llm_call_name.set(name)
                    try:
                        return await func_or_class(*args, **kwargs)
                    finally:
                        self.current_llm_call_name.reset(token)

                @functools.wraps(func_or_class)
                def sync_wrapper(*args, **kwargs):
                    token = self.current_llm_call_name.set(name)
                    try:
                        return func_or_class(*args, **kwargs)
                    finally:
                        self.current_llm_call_name.reset(token)

                return async_wrapper if asyncio.iscoroutinefunction(func_or_class) else sync_wrapper

        return decorator

    def unpatch_llm_calls(self):
        """Remove all patches"""
        for obj, method_name, original_method in self.patches:
            if hasattr(obj, method_name):
                setattr(obj, method_name, original_method)
        self.patches.clear()

    def _sanitize_api_keys(self, data):
        """Remove sensitive information from data"""
        if isinstance(data, dict):
            return {k: self._sanitize_api_keys(v) for k, v in data.items()
                    if not any(sensitive in k.lower() for sensitive in ['key', 'token', 'secret', 'password'])}
        elif isinstance(data, list):
            return [self._sanitize_api_keys(item) for item in data]
        elif isinstance(data, tuple):
            return tuple(self._sanitize_api_keys(item) for item in data)
        return data

    def _create_llm_component(self, component_id, hash_id, name, llm_type, version, memory_used, start_time, end_time, input_data, output_data, usage=None, error=None):
        cost = None
        tokens = None

        if usage:
            tokens = {
                "prompt_tokens": usage.get("prompt_tokens", 0),
                "completion_tokens": usage.get("completion_tokens", 0),
                "total_tokens": usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0)
            }
            cost = calculate_cost(usage)

            # Update total metrics
            self.total_tokens += tokens["total_tokens"]
            self.total_cost += cost["total"]

        component = {
            "id": component_id,
            "hash_id": hash_id,
            "source_hash_id": None,
            "type": "llm",
            "name": name,
            "start_time": start_time.isoformat(),
            "end_time": end_time.isoformat(),
            "error": error,
            "parent_id": self.current_agent_id.get(),
            "info": {
                "llm_type": llm_type,
                "version": version,
                "memory_used": memory_used,
                "cost": cost,
                "tokens": tokens
            },
            "data": {
                "input": input_data,
                "output": output_data.output_response if output_data else None,
                "memory_used": memory_used
            },
            "network_calls": self.component_network_calls.get(component_id, []),
            "interactions": [
                {
                    "id": f"int_{uuid.uuid4()}",
                    "interaction_type": "input",
                    "timestamp": start_time.isoformat(),
                    "content": input_data
                },
                {
                    "id": f"int_{uuid.uuid4()}",
                    "interaction_type": "output",
                    "timestamp": end_time.isoformat(),
                    "content": output_data.output_response if output_data else None
                }
            ]
        }

        return component


def extract_llm_output(result):
    """Extract output from LLM response"""
    class OutputResponse:
        def __init__(self, output_response):
            self.output_response = output_response

    # Handle coroutines
    if asyncio.iscoroutine(result):
        # For sync context, run the coroutine
        if not asyncio.get_event_loop().is_running():
            result = asyncio.run(result)
        else:
            # We're in an async context, but this function is called synchronously
            # Return a placeholder and let the caller handle the coroutine
            return OutputResponse("Coroutine result pending")

    # Handle Google GenerativeAI format
    if hasattr(result, "result"):
        candidates = getattr(result.result, "candidates", [])
        output = []
        for candidate in candidates:
            content = getattr(candidate, "content", None)
            if content and hasattr(content, "parts"):
                for part in content.parts:
                    if hasattr(part, "text"):
                        output.append({
                            "content": part.text,
                            "role": getattr(content, "role", "assistant"),
                            "finish_reason": getattr(candidate, "finish_reason", None)
                        })
        return OutputResponse(output)

    # Handle Vertex AI format
    if hasattr(result, "text"):
        return OutputResponse([{
            "content": result.text,
            "role": "assistant"
        }])

    # Handle OpenAI format
    if hasattr(result, "choices"):
        return OutputResponse([{
            "content": choice.message.content,
            "role": choice.message.role
        } for choice in result.choices])

    # Handle Anthropic format
    if hasattr(result, "completion"):
        return OutputResponse([{
            "content": result.completion,
            "role": "assistant"
        }])

    # Default case
    return OutputResponse(str(result))