ragaai-catalyst 2.1.5b21__py3-none-any.whl → 2.1.5b22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,7 +8,9 @@ import markdown
8
8
  import pandas as pd
9
9
  import json
10
10
  from litellm import completion
11
+ import litellm
11
12
  from tqdm import tqdm
13
+ import tiktoken
12
14
  # import internal_api_completion
13
15
  # import proxy_call
14
16
  from .internal_api_completion import api_completion as internal_api_completion
@@ -48,13 +50,18 @@ class SyntheticDataGeneration:
48
50
  Raises:
49
51
  ValueError: If an invalid provider is specified or API key is missing.
50
52
  """
53
+ text_validity = self.validate_input(text)
54
+ if text_validity:
55
+ raise ValueError(text_validity)
56
+
51
57
  BATCH_SIZE = 5 # Optimal batch size for maintaining response quality
52
58
  provider = model_config.get("provider")
53
59
  model = model_config.get("model")
54
60
  api_base = model_config.get("api_base")
61
+ api_version = model_config.get("api_version")
55
62
 
56
63
  # Initialize the appropriate client based on provider
57
- self._initialize_client(provider, api_key, api_base, internal_llm_proxy=kwargs.get("internal_llm_proxy", None))
64
+ self._initialize_client(provider, api_key, api_base, api_version, internal_llm_proxy=kwargs.get("internal_llm_proxy", None))
58
65
 
59
66
  # Initialize progress bar
60
67
  pbar = tqdm(total=n, desc="Generating QA pairs")
@@ -88,7 +95,7 @@ class SyntheticDataGeneration:
88
95
  pbar.update(len(batch_df))
89
96
 
90
97
  except Exception as e:
91
- print(f"Batch generation failed.")
98
+ print(f"Batch generation failed:{str(e)}")
92
99
 
93
100
  if any(error in str(e) for error in FAILURE_CASES):
94
101
  raise Exception(f"{e}")
@@ -139,7 +146,7 @@ class SyntheticDataGeneration:
139
146
 
140
147
  return final_df
141
148
 
142
- def _initialize_client(self, provider, api_key, api_base=None, internal_llm_proxy=None):
149
+ def _initialize_client(self, provider, api_key, api_base=None, api_version=None, internal_llm_proxy=None):
143
150
  """Initialize the appropriate client based on provider."""
144
151
  if not provider:
145
152
  raise ValueError("Model configuration must be provided with a valid provider and model.")
@@ -158,7 +165,17 @@ class SyntheticDataGeneration:
158
165
  if api_key is None and os.getenv("OPENAI_API_KEY") is None and internal_llm_proxy is None:
159
166
  raise ValueError("API key must be provided for OpenAI.")
160
167
  openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
161
-
168
+
169
+ elif provider == "azure":
170
+ if api_key is None and os.getenv("AZURE_API_KEY") is None and internal_llm_proxy is None:
171
+ raise ValueError("API key must be provided for Azure.")
172
+ litellm.api_key = api_key or os.getenv("AZURE_API_KEY")
173
+ if api_base is None and os.getenv("AZURE_API_BASE") is None and internal_llm_proxy is None:
174
+ raise ValueError("API Base must be provided for Azure.")
175
+ litellm.api_base = api_base or os.getenv("AZURE_API_BASE")
176
+ if api_version is None and os.getenv("AZURE_API_VERSION") is None and internal_llm_proxy is None:
177
+ raise ValueError("API version must be provided for Azure.")
178
+ litellm.api_version = api_version or os.getenv("AZURE_API_VERSION")
162
179
  else:
163
180
  raise ValueError(f"Provider is not recognized.")
164
181
 
@@ -189,7 +206,15 @@ class SyntheticDataGeneration:
189
206
  kwargs=kwargs
190
207
  )
191
208
 
209
+ def validate_input(self,text):
192
210
 
211
+ if not text.strip():
212
+ return 'Empty Text provided for qna generation. Please provide valid text'
213
+ encoding = tiktoken.encoding_for_model("gpt-4")
214
+ tokens = encoding.encode(text)
215
+ if len(tokens)<5:
216
+ return 'Very Small Text provided for qna generation. Please provide longer text'
217
+ return False
193
218
 
194
219
 
195
220
  def _get_system_message(self, question_type, n):
@@ -274,10 +299,14 @@ class SyntheticDataGeneration:
274
299
  # Add optional parameters if they exist in model_config
275
300
  if "api_base" in model_config:
276
301
  completion_params["api_base"] = model_config["api_base"]
302
+ if "api_version" in model_config:
303
+ completion_params["api_version"] = model_config["api_version"]
277
304
  if "max_tokens" in model_config:
278
305
  completion_params["max_tokens"] = model_config["max_tokens"]
279
306
  if "temperature" in model_config:
280
307
  completion_params["temperature"] = model_config["temperature"]
308
+ if 'provider' in model_config:
309
+ completion_params['model'] = f'{model_config["provider"]}/{model_config["model"]}'
281
310
 
282
311
  # Make the API call using LiteLLM
283
312
  try:
@@ -318,9 +347,13 @@ class SyntheticDataGeneration:
318
347
  list_start_index = data.find('[') # Find the index of the first '['
319
348
  substring_data = data[list_start_index:] if list_start_index != -1 else data # Slice from the list start
320
349
  data = substring_data
321
-
350
+ elif provider == "azure":
351
+ data = response.choices[0].message.content.replace('\n', '')
352
+ list_start_index = data.find('[') # Find the index of the first '['
353
+ substring_data = data[list_start_index:] if list_start_index != -1 else data # Slice from the list start
354
+ data = substring_data
322
355
  else:
323
- raise ValueError("Invalid provider. Choose 'groq', 'gemini', or 'openai'.")
356
+ raise ValueError("Invalid provider. Choose 'groq', 'gemini', 'azure' or 'openai'.")
324
357
  try:
325
358
  json_data = json.loads(data)
326
359
  return pd.DataFrame(json_data)
@@ -101,7 +101,10 @@ class AgentTracerMixin:
101
101
  original_init = target.__init__
102
102
 
103
103
  def wrapped_init(self, *args, **kwargs):
104
- self.gt = kwargs.get("gt", None) if kwargs else None
104
+ gt = kwargs.get("gt") if kwargs else None
105
+ if gt is not None:
106
+ span = self.span(name)
107
+ span.add_gt(gt)
105
108
  # Set agent context before initializing
106
109
  component_id = str(uuid.uuid4())
107
110
  hash_id = top_level_hash_id
@@ -159,7 +162,10 @@ class AgentTracerMixin:
159
162
  @self.file_tracker.trace_decorator
160
163
  @functools.wraps(method)
161
164
  def wrapped_method(self, *args, **kwargs):
162
- self.gt = kwargs.get("gt", None) if kwargs else None
165
+ gt = kwargs.get("gt") if kwargs else None
166
+ if gt is not None:
167
+ span = tracer.span(name)
168
+ span.add_gt(gt)
163
169
  # Set this agent as current during method execution
164
170
  token = tracer.current_agent_id.set(
165
171
  self._agent_component_id
@@ -247,6 +253,7 @@ class AgentTracerMixin:
247
253
  agent_type,
248
254
  version,
249
255
  capabilities,
256
+ top_level_hash_id,
250
257
  *args,
251
258
  **kwargs,
252
259
  )
@@ -256,10 +263,9 @@ class AgentTracerMixin:
256
263
  return decorator
257
264
 
258
265
  def _trace_sync_agent_execution(
259
- self, func, name, agent_type, version, capabilities, *args, **kwargs
266
+ self, func, name, agent_type, version, capabilities, top_level_hash_id, *args, **kwargs
260
267
  ):
261
- # Generate a unique hash_id for this execution context
262
- hash_id = str(uuid.uuid4())
268
+ hash_id = top_level_hash_id
263
269
 
264
270
  """Synchronous version of agent tracing"""
265
271
  if not self.is_active:
@@ -275,7 +281,10 @@ class AgentTracerMixin:
275
281
  component_id = str(uuid.uuid4())
276
282
 
277
283
  # Extract ground truth if present
278
- ground_truth = kwargs.pop("gt", None) if kwargs else None
284
+ ground_truth = kwargs.pop("gt") if kwargs else None
285
+ if ground_truth is not None:
286
+ span = self.span(name)
287
+ span.add_gt(ground_truth)
279
288
 
280
289
  # Get parent agent ID if exists
281
290
  parent_agent_id = self.current_agent_id.get()
@@ -293,7 +302,7 @@ class AgentTracerMixin:
293
302
 
294
303
  try:
295
304
  # Execute the agent
296
- result = func(*args, **kwargs)
305
+ result = self.file_tracker.trace_wrapper(func)(*args, **kwargs)
297
306
 
298
307
  # Calculate resource usage
299
308
  end_memory = psutil.Process().memory_info().rss
@@ -320,9 +329,6 @@ class AgentTracerMixin:
320
329
  children=children,
321
330
  parent_id=parent_agent_id,
322
331
  )
323
- # Add ground truth to component data if present
324
- if ground_truth is not None:
325
- agent_component["data"]["gt"] = ground_truth
326
332
 
327
333
  # Add this component as a child to parent's children list
328
334
  parent_children.append(agent_component)
@@ -398,7 +404,10 @@ class AgentTracerMixin:
398
404
  component_id = str(uuid.uuid4())
399
405
 
400
406
  # Extract ground truth if present
401
- ground_truth = kwargs.pop("gt", None) if kwargs else None
407
+ ground_truth = kwargs.pop("gt") if kwargs else None
408
+ if ground_truth is not None:
409
+ span = self.span(name)
410
+ span.add_gt(ground_truth)
402
411
 
403
412
  # Get parent agent ID if exists
404
413
  parent_agent_id = self.current_agent_id.get()
@@ -414,7 +423,7 @@ class AgentTracerMixin:
414
423
 
415
424
  try:
416
425
  # Execute the agent
417
- result = await func(*args, **kwargs)
426
+ result = await self.file_tracker.trace_wrapper(func)(*args, **kwargs)
418
427
 
419
428
  # Calculate resource usage
420
429
  end_memory = psutil.Process().memory_info().rss
@@ -441,10 +450,6 @@ class AgentTracerMixin:
441
450
  parent_id=parent_agent_id,
442
451
  )
443
452
 
444
- # Add ground truth to component data if present
445
- if ground_truth is not None:
446
- agent_component["data"]["gt"] = ground_truth
447
-
448
453
  # Add this component as a child to parent's children list
449
454
  parent_children.append(agent_component)
450
455
  self.agent_children.set(parent_children)
@@ -576,8 +581,13 @@ class AgentTracerMixin:
576
581
  "interactions": interactions,
577
582
  }
578
583
 
579
- if self.gt:
580
- component["data"]["gt"] = self.gt
584
+ if name in self.span_attributes_dict:
585
+ span_gt = self.span_attributes_dict[name].gt
586
+ if span_gt is not None:
587
+ component["data"]["gt"] = span_gt
588
+ span_context = self.span_attributes_dict[name].context
589
+ if span_context:
590
+ component["data"]["context"] = span_context
581
591
 
582
592
  # Reset the SpanAttributes context variable
583
593
  self.span_attributes_dict[kwargs["name"]] = SpanAttributes(kwargs["name"])
@@ -83,6 +83,7 @@ class BaseTracer:
83
83
  self.tracking_thread = None
84
84
  self.tracking = False
85
85
  self.system_monitor = None
86
+ self.gt = None
86
87
 
87
88
  def _get_system_info(self) -> SystemInfo:
88
89
  return self.system_monitor.get_system_info()
@@ -249,7 +250,8 @@ class BaseTracer:
249
250
 
250
251
  # Format interactions and add to trace
251
252
  interactions = self.format_interactions()
252
- trace_data["workflow"] = interactions["workflow"]
253
+ # trace_data["workflow"] = interactions["workflow"]
254
+ cleaned_trace_data["workflow"] = interactions["workflow"]
253
255
 
254
256
  with open(filepath, "w") as f:
255
257
  json.dump(cleaned_trace_data, f, cls=TracerJSONEncoder, indent=2)
@@ -45,7 +45,10 @@ class CustomTracerMixin:
45
45
  @functools.wraps(func)
46
46
  async def async_wrapper(*args, **kwargs):
47
47
  async_wrapper.metadata = metadata
48
- self.gt = kwargs.get('gt', None) if kwargs else None
48
+ gt = kwargs.get('gt') if kwargs else None
49
+ if gt is not None:
50
+ span = self.span(name)
51
+ span.add_gt(gt)
49
52
  return await self._trace_custom_execution(
50
53
  func, name or func.__name__, custom_type, version, trace_variables, *args, **kwargs
51
54
  )
@@ -54,7 +57,10 @@ class CustomTracerMixin:
54
57
  @functools.wraps(func)
55
58
  def sync_wrapper(*args, **kwargs):
56
59
  sync_wrapper.metadata = metadata
57
- self.gt = kwargs.get('gt', None) if kwargs else None
60
+ gt = kwargs.get('gt') if kwargs else None
61
+ if gt is not None:
62
+ span = self.span(name)
63
+ span.add_gt(gt)
58
64
  return self._trace_sync_custom_execution(
59
65
  func, name or func.__name__, custom_type, version, trace_variables, *args, **kwargs
60
66
  )
@@ -98,7 +104,7 @@ class CustomTracerMixin:
98
104
 
99
105
  try:
100
106
  # Execute the function
101
- result = func(*args, **kwargs)
107
+ result = self.file_tracker.trace_wrapper(func)(*args, **kwargs)
102
108
 
103
109
  # Calculate resource usage
104
110
  end_time = datetime.now().astimezone().isoformat()
@@ -186,7 +192,7 @@ class CustomTracerMixin:
186
192
 
187
193
  try:
188
194
  # Execute the function
189
- result = await func(*args, **kwargs)
195
+ result = await self.file_tracker.trace_wrapper(func)(*args, **kwargs)
190
196
 
191
197
  # Calculate resource usage
192
198
  end_time = datetime.now().astimezone().isoformat()
@@ -284,9 +290,13 @@ class CustomTracerMixin:
284
290
  "interactions": interactions
285
291
  }
286
292
 
287
- if self.gt:
288
- component["data"]["gt"] = self.gt
289
-
293
+ if kwargs["name"] in self.span_attributes_dict:
294
+ span_gt = self.span_attributes_dict[kwargs["name"]].gt
295
+ if span_gt is not None:
296
+ component["data"]["gt"] = span_gt
297
+ span_context = self.span_attributes_dict[kwargs["name"]].context
298
+ if span_context:
299
+ component["data"]["context"] = span_context
290
300
  return component
291
301
 
292
302
  def start_component(self, component_id):
@@ -12,6 +12,7 @@ import contextvars
12
12
  import traceback
13
13
  import importlib
14
14
  import sys
15
+ from litellm import model_cost
15
16
 
16
17
  from ..utils.llm_utils import (
17
18
  extract_model_name,
@@ -24,8 +25,7 @@ from ..utils.llm_utils import (
24
25
  extract_llm_output,
25
26
  num_tokens_from_messages
26
27
  )
27
- from ..utils.trace_utils import load_model_costs
28
- from ..utils.unique_decorator import generate_unique_hash_simple
28
+ from ..utils.unique_decorator import generate_unique_hash
29
29
  from ..utils.file_name_tracker import TrackName
30
30
  from ..utils.span_attributes import SpanAttributes
31
31
  import logging
@@ -44,7 +44,7 @@ class LLMTracerMixin:
44
44
  self.file_tracker = TrackName()
45
45
  self.patches = []
46
46
  try:
47
- self.model_costs = load_model_costs()
47
+ self.model_costs = model_cost
48
48
  except Exception as e:
49
49
  self.model_costs = {
50
50
  "default": {"input_cost_per_token": 0.0, "output_cost_per_token": 0.0}
@@ -97,6 +97,10 @@ class LLMTracerMixin:
97
97
  self.patch_langchain_google_methods(sys.modules["langchain_google_vertexai"])
98
98
  if "langchain_google_genai" in sys.modules:
99
99
  self.patch_langchain_google_methods(sys.modules["langchain_google_genai"])
100
+ if "langchain_openai" in sys.modules:
101
+ self.patch_langchain_openai_methods(sys.modules["langchain_openai"])
102
+ if "langchain_anthropic" in sys.modules:
103
+ self.patch_langchain_anthropic_methods(sys.modules["langchain_anthropic"])
100
104
 
101
105
  # Register hooks for future imports with availability checks
102
106
  if self.check_package_available("vertexai"):
@@ -130,6 +134,15 @@ class LLMTracerMixin:
130
134
  wrapt.register_post_import_hook(
131
135
  self.patch_langchain_google_methods, "langchain_google_genai"
132
136
  )
137
+
138
+ if self.check_package_available("langchain_openai"):
139
+ wrapt.register_post_import_hook(
140
+ self.patch_langchain_openai_methods, "langchain_openai"
141
+ )
142
+ if self.check_package_available("langchain_anthropic"):
143
+ wrapt.register_post_import_hook(
144
+ self.patch_langchain_anthropic_methods, "langchain_anthropic"
145
+ )
133
146
 
134
147
  def instrument_user_interaction_calls(self):
135
148
  """Enable user interaction instrumentation for LLM calls"""
@@ -154,6 +167,42 @@ class LLMTracerMixin:
154
167
  except Exception as e:
155
168
  # Log the error but continue execution
156
169
  print(f"Warning: Failed to patch OpenAI methods: {str(e)}")
170
+
171
+ def patch_langchain_openai_methods(self, module):
172
+ try:
173
+ if hasattr(module, 'ChatOpenAI'):
174
+ client_class = getattr(module, "ChatOpenAI")
175
+
176
+ if hasattr(client_class, "invoke"):
177
+ self.wrap_langchain_openai_method(client_class, f"{client_class.__name__}.invoke")
178
+ elif hasattr(client_class, "run"):
179
+ self.wrap_langchain_openai_method(client_class, f"{client_class.__name__}.run")
180
+ if hasattr(module, 'AsyncChatOpenAI'):
181
+ if hasattr(client_class, "ainvoke"):
182
+ self.wrap_langchain_openai_method(client_class, f"{client_class.__name__}.ainvoke")
183
+ elif hasattr(client_class, "arun"):
184
+ self.wrap_langchain_openai_method(client_class, f"{client_class.__name__}.arun")
185
+ except Exception as e:
186
+ # Log the error but continue execution
187
+ print(f"Warning: Failed to patch OpenAI methods: {str(e)}")
188
+
189
+ def patch_langchain_anthropic_methods(self, module):
190
+ try:
191
+ if hasattr(module, 'ChatAnthropic'):
192
+ client_class = getattr(module, "ChatAnthropic")
193
+ if hasattr(client_class, "invoke"):
194
+ self.wrap_langchain_anthropic_method(client_class, f"{client_class.__name__}.invoke")
195
+ if hasattr(client_class, "ainvoke"):
196
+ self.wrap_langchain_anthropic_method(client_class, f"{client_class.__name__}.ainvoke")
197
+ if hasattr(module, 'AsyncChatAnthropic'):
198
+ async_client_class = getattr(module, "AsyncChatAnthropic")
199
+ if hasattr(async_client_class, "ainvoke"):
200
+ self.wrap_langchain_anthropic_method(async_client_class, f"{async_client_class.__name__}.ainvoke")
201
+ if hasattr(async_client_class, "arun"):
202
+ self.wrap_langchain_anthropic_method(async_client_class, f"{async_client_class.__name__}.arun")
203
+ except Exception as e:
204
+ # Log the error but continue execution
205
+ print(f"Warning: Failed to patch Anthropic methods: {str(e)}")
157
206
 
158
207
  def patch_openai_beta_methods(self, openai_module):
159
208
  """
@@ -293,7 +342,6 @@ class LLMTracerMixin:
293
342
  return await self.trace_llm_call(
294
343
  original_create, *args, **kwargs
295
344
  )
296
-
297
345
  client_self.chat.completions.create = wrapped_create
298
346
  else:
299
347
  # Patch sync methods for OpenAI
@@ -305,10 +353,39 @@ class LLMTracerMixin:
305
353
  return self.trace_llm_call_sync(
306
354
  original_create, *args, **kwargs
307
355
  )
308
-
309
356
  client_self.chat.completions.create = wrapped_create
310
357
 
311
358
  setattr(client_class, "__init__", patched_init)
359
+
360
+ def wrap_langchain_openai_method(self, client_class, method_name):
361
+ method = method_name.split(".")[-1]
362
+ original_init = getattr(client_class, method)
363
+
364
+ @functools.wraps(original_init)
365
+ def patched_init(*args, **kwargs):
366
+ # Check if this is AsyncOpenAI or OpenAI
367
+ is_async = "AsyncChatOpenAI" in client_class.__name__
368
+
369
+ if is_async:
370
+ return self.trace_llm_call(original_init, *args, **kwargs)
371
+ else:
372
+ return self.trace_llm_call_sync(original_init, *args, **kwargs)
373
+
374
+ setattr(client_class, method, patched_init)
375
+
376
+ def wrap_langchain_anthropic_method(self, client_class, method_name):
377
+ original_init = getattr(client_class, method_name)
378
+
379
+ @functools.wraps(original_init)
380
+ def patched_init(*args, **kwargs):
381
+ is_async = "AsyncChatAnthropic" in client_class.__name__
382
+
383
+ if is_async:
384
+ return self.trace_llm_call(original_init, *args, **kwargs)
385
+ else:
386
+ return self.trace_llm_call_sync(original_init, *args, **kwargs)
387
+
388
+ setattr(client_class, method_name, patched_init)
312
389
 
313
390
  def wrap_anthropic_client_methods(self, client_class):
314
391
  original_init = client_class.__init__
@@ -475,8 +552,13 @@ class LLMTracerMixin:
475
552
  "interactions": interactions,
476
553
  }
477
554
 
478
- if self.gt:
479
- component["data"]["gt"] = self.gt
555
+ if name in self.span_attributes_dict:
556
+ span_gt = self.span_attributes_dict[name].gt
557
+ if span_gt is not None:
558
+ component["data"]["gt"] = span_gt
559
+ span_context = self.span_attributes_dict[name].context
560
+ if span_context:
561
+ component["data"]["context"] = span_context
480
562
 
481
563
  # Reset the SpanAttributes context variable
482
564
  self.span_attributes_dict[name] = SpanAttributes(name)
@@ -503,7 +585,7 @@ class LLMTracerMixin:
503
585
  start_time = datetime.now().astimezone()
504
586
  start_memory = psutil.Process().memory_info().rss
505
587
  component_id = str(uuid.uuid4())
506
- hash_id = generate_unique_hash_simple(original_func)
588
+ hash_id = generate_unique_hash(original_func, args, kwargs)
507
589
 
508
590
  # Start tracking network calls for this component
509
591
  self.start_component(component_id)
@@ -605,7 +687,7 @@ class LLMTracerMixin:
605
687
 
606
688
  start_time = datetime.now().astimezone()
607
689
  component_id = str(uuid.uuid4())
608
- hash_id = generate_unique_hash_simple(original_func)
690
+ hash_id = generate_unique_hash(original_func, args, kwargs)
609
691
 
610
692
  # Start tracking network calls for this component
611
693
  self.start_component(component_id)
@@ -745,12 +827,14 @@ class LLMTracerMixin:
745
827
  @self.file_tracker.trace_decorator
746
828
  @functools.wraps(func)
747
829
  async def async_wrapper(*args, **kwargs):
748
- self.gt = kwargs.get("gt", None) if kwargs else None
830
+ gt = kwargs.get("gt") if kwargs else None
831
+ if gt is not None:
832
+ span = self.span(name)
833
+ span.add_gt(gt)
749
834
  self.current_llm_call_name.set(name)
750
835
  if not self.is_active:
751
836
  return await func(*args, **kwargs)
752
837
 
753
- hash_id = generate_unique_hash_simple(func)
754
838
  component_id = str(uuid.uuid4())
755
839
  parent_agent_id = self.current_agent_id.get()
756
840
  self.start_component(component_id)
@@ -777,8 +861,13 @@ class LLMTracerMixin:
777
861
  if (name is not None) or (name != ""):
778
862
  llm_component["name"] = name
779
863
 
780
- if self.gt:
781
- llm_component["data"]["gt"] = self.gt
864
+ if name in self.span_attributes_dict:
865
+ span_gt = self.span_attributes_dict[name].gt
866
+ if span_gt is not None:
867
+ llm_component["data"]["gt"] = span_gt
868
+ span_context = self.span_attributes_dict[name].context
869
+ if span_context:
870
+ llm_component["data"]["context"] = span_context
782
871
 
783
872
  if error_info:
784
873
  llm_component["error"] = error_info["error"]
@@ -811,13 +900,14 @@ class LLMTracerMixin:
811
900
  @self.file_tracker.trace_decorator
812
901
  @functools.wraps(func)
813
902
  def sync_wrapper(*args, **kwargs):
814
- self.gt = kwargs.get("gt", None) if kwargs else None
903
+ gt = kwargs.get("gt") if kwargs else None
904
+ if gt is not None:
905
+ span = self.span(name)
906
+ span.add_gt(gt)
815
907
  self.current_llm_call_name.set(name)
816
908
  if not self.is_active:
817
909
  return func(*args, **kwargs)
818
910
 
819
- hash_id = generate_unique_hash_simple(func)
820
-
821
911
  component_id = str(uuid.uuid4())
822
912
  parent_agent_id = self.current_agent_id.get()
823
913
  self.start_component(component_id)
@@ -258,7 +258,10 @@ class ToolTracerMixin:
258
258
  @functools.wraps(func)
259
259
  async def async_wrapper(*args, **kwargs):
260
260
  async_wrapper.metadata = metadata
261
- self.gt = kwargs.get("gt", None) if kwargs else None
261
+ gt = kwargs.get("gt") if kwargs else None
262
+ if gt is not None:
263
+ span = self.span(name)
264
+ span.add_gt(gt)
262
265
  return await self._trace_tool_execution(
263
266
  func, name, tool_type, version, *args, **kwargs
264
267
  )
@@ -267,7 +270,10 @@ class ToolTracerMixin:
267
270
  @functools.wraps(func)
268
271
  def sync_wrapper(*args, **kwargs):
269
272
  sync_wrapper.metadata = metadata
270
- self.gt = kwargs.get("gt", None) if kwargs else None
273
+ gt = kwargs.get("gt") if kwargs else None
274
+ if gt is not None:
275
+ span = self.span(name)
276
+ span.add_gt(gt)
271
277
  return self._trace_sync_tool_execution(
272
278
  func, name, tool_type, version, *args, **kwargs
273
279
  )
@@ -302,7 +308,7 @@ class ToolTracerMixin:
302
308
 
303
309
  try:
304
310
  # Execute the tool
305
- result = func(*args, **kwargs)
311
+ result = self.file_tracker.trace_wrapper(func)(*args, **kwargs)
306
312
 
307
313
  # Calculate resource usage
308
314
  end_memory = psutil.Process().memory_info().rss
@@ -384,7 +390,7 @@ class ToolTracerMixin:
384
390
  self.start_component(component_id)
385
391
  try:
386
392
  # Execute the tool
387
- result = await func(*args, **kwargs)
393
+ result = await self.file_tracker.trace_wrapper(func)(*args, **kwargs)
388
394
 
389
395
  # Calculate resource usage
390
396
  end_memory = psutil.Process().memory_info().rss
@@ -504,8 +510,13 @@ class ToolTracerMixin:
504
510
  "interactions": interactions,
505
511
  }
506
512
 
507
- if self.gt:
508
- component["data"]["gt"] = self.gt
513
+ if name in self.span_attributes_dict:
514
+ span_gt = self.span_attributes_dict[name].gt
515
+ if span_gt is not None:
516
+ component["data"]["gt"] = span_gt
517
+ span_context = self.span_attributes_dict[name].context
518
+ if span_context:
519
+ component["data"]["context"] = span_context
509
520
 
510
521
  # Reset the SpanAttributes context variable
511
522
  self.span_attributes_dict[kwargs["name"]] = SpanAttributes(kwargs["name"])
@@ -8,6 +8,7 @@ def upload_trace_metric(json_file_path, dataset_name, project_name):
8
8
  try:
9
9
  with open(json_file_path, "r") as f:
10
10
  traces = json.load(f)
11
+
11
12
  metrics = get_trace_metrics_from_trace(traces)
12
13
  metrics = _change_metrics_format_for_payload(metrics)
13
14
 
@@ -21,7 +22,6 @@ def upload_trace_metric(json_file_path, dataset_name, project_name):
21
22
  metricConfig = next((user_metric["metricConfig"] for user_metric in user_trace_metrics if user_metric["displayName"] == metric["displayName"]), None)
22
23
  if not metricConfig or metricConfig.get("Metric Source", {}).get("value") != "user":
23
24
  raise ValueError(f"Metrics {metric['displayName']} already exist in dataset {dataset_name} of project {project_name}.")
24
-
25
25
  headers = {
26
26
  "Content-Type": "application/json",
27
27
  "Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
@@ -68,13 +68,14 @@ def get_trace_metrics_from_trace(traces):
68
68
  # get span level metrics
69
69
  for span in traces["data"][0]["spans"]:
70
70
  if span["type"] == "agent":
71
+ # Add children metrics of agent
71
72
  children_metric = _get_children_metrics_of_agent(span["data"]["children"])
72
73
  if children_metric:
73
74
  metrics.extend(children_metric)
74
- else:
75
- metric = span.get("metrics", [])
76
- if metric:
77
- metrics.extend(metric)
75
+
76
+ metric = span.get("metrics", [])
77
+ if metric:
78
+ metrics.extend(metric)
78
79
  return metrics
79
80
 
80
81
  def _change_metrics_format_for_payload(metrics):
@@ -8,13 +8,32 @@ class TrackName:
8
8
  def trace_decorator(self, func):
9
9
  @wraps(func)
10
10
  def wrapper(*args, **kwargs):
11
- file_name = self._get_file_name()
11
+ file_name = self._get_decorated_file_name()
12
12
  self.files.add(file_name)
13
13
 
14
14
  return func(*args, **kwargs)
15
15
  return wrapper
16
+
17
+ def trace_wrapper(self, func):
18
+ @wraps(func)
19
+ def wrapper(*args, **kwargs):
20
+ file_name = self._get_wrapped_file_name()
21
+ self.files.add(file_name)
22
+ return func(*args, **kwargs)
23
+ return wrapper
24
+
25
+ def _get_wrapped_file_name(self):
26
+ try:
27
+ from IPython import get_ipython
28
+ if 'IPKernelApp' in get_ipython().config:
29
+ return self._get_notebook_name()
30
+ except Exception:
31
+ pass
32
+
33
+ frame = inspect.stack()[4]
34
+ return frame.filename
16
35
 
17
- def _get_file_name(self):
36
+ def _get_decorated_file_name(self):
18
37
  # Check if running in a Jupyter notebook
19
38
  try:
20
39
  from IPython import get_ipython