ragaai-catalyst 2.1.5b21__py3-none-any.whl → 2.1.5b23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. ragaai_catalyst/__init__.py +3 -1
  2. ragaai_catalyst/dataset.py +49 -1
  3. ragaai_catalyst/redteaming.py +171 -0
  4. ragaai_catalyst/synthetic_data_generation.py +40 -7
  5. ragaai_catalyst/tracers/agentic_tracing/tracers/agent_tracer.py +57 -46
  6. ragaai_catalyst/tracers/agentic_tracing/tracers/base.py +218 -47
  7. ragaai_catalyst/tracers/agentic_tracing/tracers/custom_tracer.py +17 -7
  8. ragaai_catalyst/tracers/agentic_tracing/tracers/llm_tracer.py +327 -62
  9. ragaai_catalyst/tracers/agentic_tracing/tracers/main_tracer.py +0 -3
  10. ragaai_catalyst/tracers/agentic_tracing/tracers/tool_tracer.py +17 -6
  11. ragaai_catalyst/tracers/agentic_tracing/upload/upload_local_metric.py +72 -0
  12. ragaai_catalyst/tracers/agentic_tracing/upload/upload_trace_metric.py +32 -15
  13. ragaai_catalyst/tracers/agentic_tracing/utils/file_name_tracker.py +21 -2
  14. ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +33 -11
  15. ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +1204 -484
  16. ragaai_catalyst/tracers/agentic_tracing/utils/span_attributes.py +79 -10
  17. ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py +0 -32
  18. ragaai_catalyst/tracers/agentic_tracing/utils/unique_decorator.py +3 -1
  19. ragaai_catalyst/tracers/agentic_tracing/utils/zip_list_of_unique_files.py +40 -21
  20. ragaai_catalyst/tracers/distributed.py +7 -3
  21. ragaai_catalyst/tracers/tracer.py +9 -9
  22. ragaai_catalyst/tracers/utils/langchain_tracer_extraction_logic.py +0 -1
  23. {ragaai_catalyst-2.1.5b21.dist-info → ragaai_catalyst-2.1.5b23.dist-info}/METADATA +37 -2
  24. {ragaai_catalyst-2.1.5b21.dist-info → ragaai_catalyst-2.1.5b23.dist-info}/RECORD +27 -25
  25. {ragaai_catalyst-2.1.5b21.dist-info → ragaai_catalyst-2.1.5b23.dist-info}/LICENSE +0 -0
  26. {ragaai_catalyst-2.1.5b21.dist-info → ragaai_catalyst-2.1.5b23.dist-info}/WHEEL +0 -0
  27. {ragaai_catalyst-2.1.5b21.dist-info → ragaai_catalyst-2.1.5b23.dist-info}/top_level.txt +0 -0
@@ -5,6 +5,7 @@ from .dataset import Dataset
5
5
  from .prompt_manager import PromptManager
6
6
  from .evaluation import Evaluation
7
7
  from .synthetic_data_generation import SyntheticDataGeneration
8
+ from .redteaming import RedTeaming
8
9
  from .guardrails_manager import GuardrailsManager
9
10
  from .guard_executor import GuardExecutor
10
11
  from .tracers import Tracer, init_tracing, trace_agent, trace_llm, trace_tool, current_span, trace_custom
@@ -18,7 +19,8 @@ __all__ = [
18
19
  "Tracer",
19
20
  "PromptManager",
20
21
  "Evaluation",
21
- "SyntheticDataGeneration",
22
+ "SyntheticDataGeneration",
23
+ "RedTeaming",
22
24
  "GuardrailsManager",
23
25
  "GuardExecutor",
24
26
  "init_tracing",
@@ -1,5 +1,7 @@
1
1
  import os
2
+ import csv
2
3
  import json
4
+ import tempfile
3
5
  import requests
4
6
  from .utils import response_checker
5
7
  from typing import Union
@@ -653,4 +655,50 @@ class Dataset:
653
655
  return JOB_STATUS_FAILED
654
656
  except Exception as e:
655
657
  logger.error(f"An unexpected error occurred: {e}")
656
- return JOB_STATUS_FAILED
658
+ return JOB_STATUS_FAILED
659
+
660
+ def _jsonl_to_csv(self, jsonl_file, csv_file):
661
+ """Convert a JSONL file to a CSV file."""
662
+ with open(jsonl_file, 'r', encoding='utf-8') as infile:
663
+ data = [json.loads(line) for line in infile]
664
+
665
+ if not data:
666
+ print("Empty JSONL file.")
667
+ return
668
+
669
+ with open(csv_file, 'w', newline='', encoding='utf-8') as outfile:
670
+ writer = csv.DictWriter(outfile, fieldnames=data[0].keys())
671
+ writer.writeheader()
672
+ writer.writerows(data)
673
+
674
+ print(f"Converted {jsonl_file} to {csv_file}")
675
+
676
+ def create_from_jsonl(self, jsonl_path, dataset_name, schema_mapping):
677
+ tmp_csv_path = os.path.join(tempfile.gettempdir(), f"{dataset_name}.csv")
678
+ try:
679
+ self._jsonl_to_csv(jsonl_path, tmp_csv_path)
680
+ self.create_from_csv(tmp_csv_path, dataset_name, schema_mapping)
681
+ except (IOError, UnicodeError) as e:
682
+ logger.error(f"Error converting JSONL to CSV: {e}")
683
+ raise
684
+ finally:
685
+ if os.path.exists(tmp_csv_path):
686
+ try:
687
+ os.remove(tmp_csv_path)
688
+ except Exception as e:
689
+ logger.error(f"Error removing temporary CSV file: {e}")
690
+
691
+ def add_rows_from_jsonl(self, jsonl_path, dataset_name):
692
+ tmp_csv_path = os.path.join(tempfile.gettempdir(), f"{dataset_name}.csv")
693
+ try:
694
+ self._jsonl_to_csv(jsonl_path, tmp_csv_path)
695
+ self.add_rows(tmp_csv_path, dataset_name)
696
+ except (IOError, UnicodeError) as e:
697
+ logger.error(f"Error converting JSONL to CSV: {e}")
698
+ raise
699
+ finally:
700
+ if os.path.exists(tmp_csv_path):
701
+ try:
702
+ os.remove(tmp_csv_path)
703
+ except Exception as e:
704
+ logger.error(f"Error removing temporary CSV file: {e}")
@@ -0,0 +1,171 @@
1
+ import logging
2
+ import os
3
+ from typing import Callable, Optional
4
+
5
+ import giskard as scanner
6
+ import pandas as pd
7
+
8
+ logging.getLogger('giskard.core').disabled = True
9
+ logging.getLogger('giskard.scanner.logger').disabled = True
10
+ logging.getLogger('giskard.models.automodel').disabled = True
11
+ logging.getLogger('giskard.datasets.base').disabled = True
12
+ logging.getLogger('giskard.utils.logging_utils').disabled = True
13
+
14
+
15
+ class RedTeaming:
16
+
17
+ def __init__(self,
18
+ provider: Optional[str] = "openai",
19
+ model: Optional[str] = None,
20
+ api_key: Optional[str] = None,
21
+ api_base: Optional[str] = None,
22
+ api_version: Optional[str] = None):
23
+ self.provider = provider.lower()
24
+ self.model = model
25
+ if not self.provider:
26
+ raise ValueError("Model configuration must be provided with a valid provider and model.")
27
+ if self.provider == "openai":
28
+ if api_key is not None:
29
+ os.environ["OPENAI_API_KEY"] = api_key
30
+ if os.getenv("OPENAI_API_KEY") is None:
31
+ raise ValueError("API key must be provided for OpenAI.")
32
+ elif self.provider == "gemini":
33
+ if api_key is not None:
34
+ os.environ["GEMINI_API_KEY"] = api_key
35
+ if os.getenv("GEMINI_API_KEY") is None:
36
+ raise ValueError("API key must be provided for Gemini.")
37
+ elif self.provider == "azure":
38
+ if api_key is not None:
39
+ os.environ["AZURE_API_KEY"] = api_key
40
+ if api_base is not None:
41
+ os.environ["AZURE_API_BASE"] = api_base
42
+ if api_version is not None:
43
+ os.environ["AZURE_API_VERSION"] = api_version
44
+ if os.getenv("AZURE_API_KEY") is None:
45
+ raise ValueError("API key must be provided for Azure.")
46
+ if os.getenv("AZURE_API_BASE") is None:
47
+ raise ValueError("API base must be provided for Azure.")
48
+ if os.getenv("AZURE_API_VERSION") is None:
49
+ raise ValueError("API version must be provided for Azure.")
50
+ else:
51
+ raise ValueError(f"Provider is not recognized.")
52
+
53
+ def run_scan(
54
+ self,
55
+ model: Callable,
56
+ evaluators: Optional[list] = None,
57
+ save_report: bool = True
58
+ ) -> pd.DataFrame:
59
+ """
60
+ Runs red teaming on the provided model and returns a DataFrame of the results.
61
+
62
+ :param model: The model function provided by the user (can be sync or async).
63
+ :param evaluators: Optional list of scan metrics to run.
64
+ :param save_report: Boolean flag indicating whether to save the scan report as a CSV file.
65
+ :return: A DataFrame containing the scan report.
66
+ """
67
+ import asyncio
68
+ import inspect
69
+
70
+ self.set_scanning_model(self.provider, self.model)
71
+
72
+ supported_evaluators = self.get_supported_evaluators()
73
+ if evaluators:
74
+ if isinstance(evaluators, str):
75
+ evaluators = [evaluators]
76
+ invalid_evaluators = [evaluator for evaluator in evaluators if evaluator not in supported_evaluators]
77
+ if invalid_evaluators:
78
+ raise ValueError(f"Invalid evaluators: {invalid_evaluators}. "
79
+ f"Allowed evaluators: {supported_evaluators}.")
80
+
81
+ # Handle async model functions by wrapping them in a sync function
82
+ if inspect.iscoroutinefunction(model):
83
+ def sync_wrapper(*args, **kwargs):
84
+ try:
85
+ # Try to get the current event loop
86
+ loop = asyncio.get_event_loop()
87
+ except RuntimeError:
88
+ # If no event loop exists (e.g., in Jupyter), create a new one
89
+ loop = asyncio.new_event_loop()
90
+ asyncio.set_event_loop(loop)
91
+
92
+ try:
93
+ # Handle both IPython and regular Python environments
94
+ import nest_asyncio
95
+ nest_asyncio.apply()
96
+ except ImportError:
97
+ pass # nest_asyncio not available, continue without it
98
+
99
+ return loop.run_until_complete(model(*args, **kwargs))
100
+ wrapped_model = sync_wrapper
101
+ else:
102
+ wrapped_model = model
103
+
104
+ model_instance = scanner.Model(
105
+ model=wrapped_model,
106
+ model_type="text_generation",
107
+ name="RagaAI's Scan",
108
+ description="RagaAI's RedTeaming Scan",
109
+ feature_names=["question"],
110
+ )
111
+
112
+ try:
113
+ report = scanner.scan(model_instance, only=evaluators, raise_exceptions=True) if evaluators \
114
+ else scanner.scan(model_instance, raise_exceptions=True)
115
+ except Exception as e:
116
+ raise RuntimeError(f"Error occurred during model scan: {str(e)}")
117
+
118
+ report_df = report.to_dataframe()
119
+
120
+ if save_report:
121
+ report_df.to_csv("raga-ai_red-teaming_scan.csv", index=False)
122
+
123
+ return report_df
124
+
125
+ def get_supported_evaluators(self):
126
+ """Contains tags corresponding to the 'llm' and 'robustness' directories in the giskard > scanner library"""
127
+ return {'control_chars_injection',
128
+ 'discrimination',
129
+ 'ethical_bias',
130
+ 'ethics',
131
+ 'faithfulness',
132
+ 'generative',
133
+ 'hallucination',
134
+ 'harmfulness',
135
+ 'implausible_output',
136
+ 'information_disclosure',
137
+ 'jailbreak',
138
+ 'llm',
139
+ 'llm_harmful_content',
140
+ 'llm_stereotypes_detector',
141
+ 'misinformation',
142
+ 'output_formatting',
143
+ 'prompt_injection',
144
+ 'robustness',
145
+ 'stereotypes',
146
+ 'sycophancy',
147
+ 'text_generation',
148
+ 'text_perturbation'}
149
+
150
+ def set_scanning_model(self, provider, model=None):
151
+ """
152
+ Sets the LLM model for Giskard based on the provider.
153
+
154
+ :param provider: The LLM provider (e.g., "openai", "gemini", "azure").
155
+ :param model: The specific model name to use (optional).
156
+ :raises ValueError: If the provider is "azure" and no model is provided.
157
+ """
158
+ default_models = {
159
+ "openai": "gpt-4o",
160
+ "gemini": "gemini-1.5-pro"
161
+ }
162
+
163
+ if provider == "azure" and model is None:
164
+ raise ValueError("Model must be provided for Azure.")
165
+
166
+ selected_model = model if model is not None else default_models.get(provider)
167
+
168
+ if selected_model is None:
169
+ raise ValueError(f"Unsupported provider: {provider}")
170
+
171
+ scanner.llm.set_llm_model(selected_model)
@@ -8,7 +8,9 @@ import markdown
8
8
  import pandas as pd
9
9
  import json
10
10
  from litellm import completion
11
+ import litellm
11
12
  from tqdm import tqdm
13
+ import tiktoken
12
14
  # import internal_api_completion
13
15
  # import proxy_call
14
16
  from .internal_api_completion import api_completion as internal_api_completion
@@ -48,13 +50,18 @@ class SyntheticDataGeneration:
48
50
  Raises:
49
51
  ValueError: If an invalid provider is specified or API key is missing.
50
52
  """
53
+ text_validity = self.validate_input(text)
54
+ if text_validity:
55
+ raise ValueError(text_validity)
56
+
51
57
  BATCH_SIZE = 5 # Optimal batch size for maintaining response quality
52
58
  provider = model_config.get("provider")
53
59
  model = model_config.get("model")
54
60
  api_base = model_config.get("api_base")
61
+ api_version = model_config.get("api_version")
55
62
 
56
63
  # Initialize the appropriate client based on provider
57
- self._initialize_client(provider, api_key, api_base, internal_llm_proxy=kwargs.get("internal_llm_proxy", None))
64
+ self._initialize_client(provider, api_key, api_base, api_version, internal_llm_proxy=kwargs.get("internal_llm_proxy", None))
58
65
 
59
66
  # Initialize progress bar
60
67
  pbar = tqdm(total=n, desc="Generating QA pairs")
@@ -88,7 +95,7 @@ class SyntheticDataGeneration:
88
95
  pbar.update(len(batch_df))
89
96
 
90
97
  except Exception as e:
91
- print(f"Batch generation failed.")
98
+ print(f"Batch generation failed:{str(e)}")
92
99
 
93
100
  if any(error in str(e) for error in FAILURE_CASES):
94
101
  raise Exception(f"{e}")
@@ -139,7 +146,7 @@ class SyntheticDataGeneration:
139
146
 
140
147
  return final_df
141
148
 
142
- def _initialize_client(self, provider, api_key, api_base=None, internal_llm_proxy=None):
149
+ def _initialize_client(self, provider, api_key, api_base=None, api_version=None, internal_llm_proxy=None):
143
150
  """Initialize the appropriate client based on provider."""
144
151
  if not provider:
145
152
  raise ValueError("Model configuration must be provided with a valid provider and model.")
@@ -158,7 +165,17 @@ class SyntheticDataGeneration:
158
165
  if api_key is None and os.getenv("OPENAI_API_KEY") is None and internal_llm_proxy is None:
159
166
  raise ValueError("API key must be provided for OpenAI.")
160
167
  openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
161
-
168
+
169
+ elif provider == "azure":
170
+ if api_key is None and os.getenv("AZURE_API_KEY") is None and internal_llm_proxy is None:
171
+ raise ValueError("API key must be provided for Azure.")
172
+ litellm.api_key = api_key or os.getenv("AZURE_API_KEY")
173
+ if api_base is None and os.getenv("AZURE_API_BASE") is None and internal_llm_proxy is None:
174
+ raise ValueError("API Base must be provided for Azure.")
175
+ litellm.api_base = api_base or os.getenv("AZURE_API_BASE")
176
+ if api_version is None and os.getenv("AZURE_API_VERSION") is None and internal_llm_proxy is None:
177
+ raise ValueError("API version must be provided for Azure.")
178
+ litellm.api_version = api_version or os.getenv("AZURE_API_VERSION")
162
179
  else:
163
180
  raise ValueError(f"Provider is not recognized.")
164
181
 
@@ -189,7 +206,15 @@ class SyntheticDataGeneration:
189
206
  kwargs=kwargs
190
207
  )
191
208
 
209
+ def validate_input(self,text):
192
210
 
211
+ if not text.strip():
212
+ return 'Empty Text provided for qna generation. Please provide valid text'
213
+ encoding = tiktoken.encoding_for_model("gpt-4")
214
+ tokens = encoding.encode(text)
215
+ if len(tokens)<5:
216
+ return 'Very Small Text provided for qna generation. Please provide longer text'
217
+ return False
193
218
 
194
219
 
195
220
  def _get_system_message(self, question_type, n):
@@ -274,10 +299,14 @@ class SyntheticDataGeneration:
274
299
  # Add optional parameters if they exist in model_config
275
300
  if "api_base" in model_config:
276
301
  completion_params["api_base"] = model_config["api_base"]
302
+ if "api_version" in model_config:
303
+ completion_params["api_version"] = model_config["api_version"]
277
304
  if "max_tokens" in model_config:
278
305
  completion_params["max_tokens"] = model_config["max_tokens"]
279
306
  if "temperature" in model_config:
280
307
  completion_params["temperature"] = model_config["temperature"]
308
+ if 'provider' in model_config:
309
+ completion_params['model'] = f'{model_config["provider"]}/{model_config["model"]}'
281
310
 
282
311
  # Make the API call using LiteLLM
283
312
  try:
@@ -318,9 +347,13 @@ class SyntheticDataGeneration:
318
347
  list_start_index = data.find('[') # Find the index of the first '['
319
348
  substring_data = data[list_start_index:] if list_start_index != -1 else data # Slice from the list start
320
349
  data = substring_data
321
-
350
+ elif provider == "azure":
351
+ data = response.choices[0].message.content.replace('\n', '')
352
+ list_start_index = data.find('[') # Find the index of the first '['
353
+ substring_data = data[list_start_index:] if list_start_index != -1 else data # Slice from the list start
354
+ data = substring_data
322
355
  else:
323
- raise ValueError("Invalid provider. Choose 'groq', 'gemini', or 'openai'.")
356
+ raise ValueError("Invalid provider. Choose 'groq', 'gemini', 'azure' or 'openai'.")
324
357
  try:
325
358
  json_data = json.loads(data)
326
359
  return pd.DataFrame(json_data)
@@ -442,7 +475,7 @@ class SyntheticDataGeneration:
442
475
  Returns:
443
476
  list: A list of supported AI providers.
444
477
  """
445
- return ['gemini', 'openai']
478
+ return ['gemini', 'openai','azure']
446
479
 
447
480
  # Usage:
448
481
  # from synthetic_data_generation import SyntheticDataGeneration
@@ -48,15 +48,15 @@ class AgentTracerMixin:
48
48
  self.auto_instrument_network = False
49
49
 
50
50
  def trace_agent(
51
- self,
52
- name: str,
53
- agent_type: str = None,
54
- version: str = None,
55
- capabilities: List[str] = None,
56
- tags: List[str] = [],
57
- metadata: Dict[str, Any] = {},
58
- metrics: List[Dict[str, Any]] = [],
59
- feedback: Optional[Any] = None,
51
+ self,
52
+ name: str,
53
+ agent_type: str = None,
54
+ version: str = None,
55
+ capabilities: List[str] = None,
56
+ tags: List[str] = [],
57
+ metadata: Dict[str, Any] = {},
58
+ metrics: List[Dict[str, Any]] = [],
59
+ feedback: Optional[Any] = None,
60
60
  ):
61
61
  if name not in self.span_attributes_dict:
62
62
  self.span_attributes_dict[name] = SpanAttributes(name)
@@ -101,7 +101,10 @@ class AgentTracerMixin:
101
101
  original_init = target.__init__
102
102
 
103
103
  def wrapped_init(self, *args, **kwargs):
104
- self.gt = kwargs.get("gt", None) if kwargs else None
104
+ gt = kwargs.get("gt") if kwargs else None
105
+ if gt is not None:
106
+ span = self.span(name)
107
+ span.add_gt(gt)
105
108
  # Set agent context before initializing
106
109
  component_id = str(uuid.uuid4())
107
110
  hash_id = top_level_hash_id
@@ -159,7 +162,10 @@ class AgentTracerMixin:
159
162
  @self.file_tracker.trace_decorator
160
163
  @functools.wraps(method)
161
164
  def wrapped_method(self, *args, **kwargs):
162
- self.gt = kwargs.get("gt", None) if kwargs else None
165
+ gt = kwargs.get("gt") if kwargs else None
166
+ if gt is not None:
167
+ span = tracer.span(name)
168
+ span.add_gt(gt)
163
169
  # Set this agent as current during method execution
164
170
  token = tracer.current_agent_id.set(
165
171
  self._agent_component_id
@@ -193,8 +199,8 @@ class AgentTracerMixin:
193
199
  children = tracer.agent_children.get()
194
200
  if children:
195
201
  if (
196
- "children"
197
- not in component["data"]
202
+ "children"
203
+ not in component["data"]
198
204
  ):
199
205
  component["data"][
200
206
  "children"
@@ -247,6 +253,7 @@ class AgentTracerMixin:
247
253
  agent_type,
248
254
  version,
249
255
  capabilities,
256
+ top_level_hash_id,
250
257
  *args,
251
258
  **kwargs,
252
259
  )
@@ -256,10 +263,9 @@ class AgentTracerMixin:
256
263
  return decorator
257
264
 
258
265
  def _trace_sync_agent_execution(
259
- self, func, name, agent_type, version, capabilities, *args, **kwargs
266
+ self, func, name, agent_type, version, capabilities, top_level_hash_id, *args, **kwargs
260
267
  ):
261
- # Generate a unique hash_id for this execution context
262
- hash_id = str(uuid.uuid4())
268
+ hash_id = top_level_hash_id
263
269
 
264
270
  """Synchronous version of agent tracing"""
265
271
  if not self.is_active:
@@ -276,6 +282,9 @@ class AgentTracerMixin:
276
282
 
277
283
  # Extract ground truth if present
278
284
  ground_truth = kwargs.pop("gt", None) if kwargs else None
285
+ if ground_truth is not None:
286
+ span = self.span(name)
287
+ span.add_gt(ground_truth)
279
288
 
280
289
  # Get parent agent ID if exists
281
290
  parent_agent_id = self.current_agent_id.get()
@@ -293,7 +302,7 @@ class AgentTracerMixin:
293
302
 
294
303
  try:
295
304
  # Execute the agent
296
- result = func(*args, **kwargs)
305
+ result = self.file_tracker.trace_wrapper(func)(*args, **kwargs)
297
306
 
298
307
  # Calculate resource usage
299
308
  end_memory = psutil.Process().memory_info().rss
@@ -320,9 +329,6 @@ class AgentTracerMixin:
320
329
  children=children,
321
330
  parent_id=parent_agent_id,
322
331
  )
323
- # Add ground truth to component data if present
324
- if ground_truth is not None:
325
- agent_component["data"]["gt"] = ground_truth
326
332
 
327
333
  # Add this component as a child to parent's children list
328
334
  parent_children.append(agent_component)
@@ -384,7 +390,7 @@ class AgentTracerMixin:
384
390
  self.agent_children.reset(children_token)
385
391
 
386
392
  async def _trace_agent_execution(
387
- self, func, name, agent_type, version, capabilities, hash_id, *args, **kwargs
393
+ self, func, name, agent_type, version, capabilities, hash_id, *args, **kwargs
388
394
  ):
389
395
  """Asynchronous version of agent tracing"""
390
396
  if not self.is_active:
@@ -399,6 +405,9 @@ class AgentTracerMixin:
399
405
 
400
406
  # Extract ground truth if present
401
407
  ground_truth = kwargs.pop("gt", None) if kwargs else None
408
+ if ground_truth is not None:
409
+ span = self.span(name)
410
+ span.add_gt(ground_truth)
402
411
 
403
412
  # Get parent agent ID if exists
404
413
  parent_agent_id = self.current_agent_id.get()
@@ -414,7 +423,7 @@ class AgentTracerMixin:
414
423
 
415
424
  try:
416
425
  # Execute the agent
417
- result = await func(*args, **kwargs)
426
+ result = await self.file_tracker.trace_wrapper(func)(*args, **kwargs)
418
427
 
419
428
  # Calculate resource usage
420
429
  end_memory = psutil.Process().memory_info().rss
@@ -441,10 +450,6 @@ class AgentTracerMixin:
441
450
  parent_id=parent_agent_id,
442
451
  )
443
452
 
444
- # Add ground truth to component data if present
445
- if ground_truth is not None:
446
- agent_component["data"]["gt"] = ground_truth
447
-
448
453
  # Add this component as a child to parent's children list
449
454
  parent_children.append(agent_component)
450
455
  self.agent_children.set(parent_children)
@@ -517,7 +522,7 @@ class AgentTracerMixin:
517
522
  for interaction in self.component_user_interaction.get(kwargs["component_id"], []):
518
523
  if interaction["interaction_type"] in ["input", "output"]:
519
524
  input_output_interactions.append(interaction)
520
- interactions.extend(input_output_interactions)
525
+ interactions.extend(input_output_interactions)
521
526
  if self.auto_instrument_file_io:
522
527
  file_io_interactions = []
523
528
  for interaction in self.component_user_interaction.get(kwargs["component_id"], []):
@@ -546,9 +551,10 @@ class AgentTracerMixin:
546
551
  counter = sum(1 for x in self.visited_metrics if x.startswith(base_metric_name))
547
552
  metric_name = f'{base_metric_name}_{counter}' if counter > 0 else base_metric_name
548
553
  self.visited_metrics.append(metric_name)
549
- metric["name"] = metric_name
554
+ metric["name"] = metric_name
550
555
  metrics.append(metric)
551
556
 
557
+ # TODO agent_trace execute metric
552
558
  component = {
553
559
  "id": kwargs["component_id"],
554
560
  "hash_id": kwargs["hash_id"],
@@ -576,8 +582,13 @@ class AgentTracerMixin:
576
582
  "interactions": interactions,
577
583
  }
578
584
 
579
- if self.gt:
580
- component["data"]["gt"] = self.gt
585
+ if name in self.span_attributes_dict:
586
+ span_gt = self.span_attributes_dict[name].gt
587
+ if span_gt is not None:
588
+ component["data"]["gt"] = span_gt
589
+ span_context = self.span_attributes_dict[name].context
590
+ if span_context:
591
+ component["data"]["context"] = span_context
581
592
 
582
593
  # Reset the SpanAttributes context variable
583
594
  self.span_attributes_dict[kwargs["name"]] = SpanAttributes(kwargs["name"])
@@ -599,22 +610,22 @@ class AgentTracerMixin:
599
610
  self.component_network_calls.set(component_network_calls)
600
611
 
601
612
  def _sanitize_input(self, args: tuple, kwargs: dict) -> dict:
602
- """Sanitize and format input data, including handling of nested lists and dictionaries."""
603
-
604
- def sanitize_value(value):
605
- if isinstance(value, (int, float, bool, str)):
606
- return value
607
- elif isinstance(value, list):
608
- return [sanitize_value(item) for item in value]
609
- elif isinstance(value, dict):
610
- return {key: sanitize_value(val) for key, val in value.items()}
611
- else:
612
- return str(value) # Convert non-standard types to string
613
+ """Sanitize and format input data, including handling of nested lists and dictionaries."""
614
+
615
+ def sanitize_value(value):
616
+ if isinstance(value, (int, float, bool, str)):
617
+ return value
618
+ elif isinstance(value, list):
619
+ return [sanitize_value(item) for item in value]
620
+ elif isinstance(value, dict):
621
+ return {key: sanitize_value(val) for key, val in value.items()}
622
+ else:
623
+ return str(value) # Convert non-standard types to string
613
624
 
614
- return {
615
- "args": [sanitize_value(arg) for arg in args],
616
- "kwargs": {key: sanitize_value(val) for key, val in kwargs.items()},
617
- }
625
+ return {
626
+ "args": [sanitize_value(arg) for arg in args],
627
+ "kwargs": {key: sanitize_value(val) for key, val in kwargs.items()},
628
+ }
618
629
 
619
630
  def _sanitize_output(self, output: Any) -> Any:
620
631
  """Sanitize and format output data"""
@@ -630,6 +641,6 @@ class AgentTracerMixin:
630
641
 
631
642
  def instrument_network_calls(self):
632
643
  self.auto_instrument_network = True
633
-
644
+
634
645
  def instrument_file_io_calls(self):
635
646
  self.auto_instrument_file_io = True