ragaai-catalyst 2.1.5b21__py3-none-any.whl → 2.1.5b23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragaai_catalyst/__init__.py +3 -1
- ragaai_catalyst/dataset.py +49 -1
- ragaai_catalyst/redteaming.py +171 -0
- ragaai_catalyst/synthetic_data_generation.py +40 -7
- ragaai_catalyst/tracers/agentic_tracing/tracers/agent_tracer.py +57 -46
- ragaai_catalyst/tracers/agentic_tracing/tracers/base.py +218 -47
- ragaai_catalyst/tracers/agentic_tracing/tracers/custom_tracer.py +17 -7
- ragaai_catalyst/tracers/agentic_tracing/tracers/llm_tracer.py +327 -62
- ragaai_catalyst/tracers/agentic_tracing/tracers/main_tracer.py +0 -3
- ragaai_catalyst/tracers/agentic_tracing/tracers/tool_tracer.py +17 -6
- ragaai_catalyst/tracers/agentic_tracing/upload/upload_local_metric.py +72 -0
- ragaai_catalyst/tracers/agentic_tracing/upload/upload_trace_metric.py +32 -15
- ragaai_catalyst/tracers/agentic_tracing/utils/file_name_tracker.py +21 -2
- ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +33 -11
- ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +1204 -484
- ragaai_catalyst/tracers/agentic_tracing/utils/span_attributes.py +79 -10
- ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py +0 -32
- ragaai_catalyst/tracers/agentic_tracing/utils/unique_decorator.py +3 -1
- ragaai_catalyst/tracers/agentic_tracing/utils/zip_list_of_unique_files.py +40 -21
- ragaai_catalyst/tracers/distributed.py +7 -3
- ragaai_catalyst/tracers/tracer.py +9 -9
- ragaai_catalyst/tracers/utils/langchain_tracer_extraction_logic.py +0 -1
- {ragaai_catalyst-2.1.5b21.dist-info → ragaai_catalyst-2.1.5b23.dist-info}/METADATA +37 -2
- {ragaai_catalyst-2.1.5b21.dist-info → ragaai_catalyst-2.1.5b23.dist-info}/RECORD +27 -25
- {ragaai_catalyst-2.1.5b21.dist-info → ragaai_catalyst-2.1.5b23.dist-info}/LICENSE +0 -0
- {ragaai_catalyst-2.1.5b21.dist-info → ragaai_catalyst-2.1.5b23.dist-info}/WHEEL +0 -0
- {ragaai_catalyst-2.1.5b21.dist-info → ragaai_catalyst-2.1.5b23.dist-info}/top_level.txt +0 -0
ragaai_catalyst/__init__.py
CHANGED
@@ -5,6 +5,7 @@ from .dataset import Dataset
|
|
5
5
|
from .prompt_manager import PromptManager
|
6
6
|
from .evaluation import Evaluation
|
7
7
|
from .synthetic_data_generation import SyntheticDataGeneration
|
8
|
+
from .redteaming import RedTeaming
|
8
9
|
from .guardrails_manager import GuardrailsManager
|
9
10
|
from .guard_executor import GuardExecutor
|
10
11
|
from .tracers import Tracer, init_tracing, trace_agent, trace_llm, trace_tool, current_span, trace_custom
|
@@ -18,7 +19,8 @@ __all__ = [
|
|
18
19
|
"Tracer",
|
19
20
|
"PromptManager",
|
20
21
|
"Evaluation",
|
21
|
-
"SyntheticDataGeneration",
|
22
|
+
"SyntheticDataGeneration",
|
23
|
+
"RedTeaming",
|
22
24
|
"GuardrailsManager",
|
23
25
|
"GuardExecutor",
|
24
26
|
"init_tracing",
|
ragaai_catalyst/dataset.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
import os
|
2
|
+
import csv
|
2
3
|
import json
|
4
|
+
import tempfile
|
3
5
|
import requests
|
4
6
|
from .utils import response_checker
|
5
7
|
from typing import Union
|
@@ -653,4 +655,50 @@ class Dataset:
|
|
653
655
|
return JOB_STATUS_FAILED
|
654
656
|
except Exception as e:
|
655
657
|
logger.error(f"An unexpected error occurred: {e}")
|
656
|
-
return JOB_STATUS_FAILED
|
658
|
+
return JOB_STATUS_FAILED
|
659
|
+
|
660
|
+
def _jsonl_to_csv(self, jsonl_file, csv_file):
|
661
|
+
"""Convert a JSONL file to a CSV file."""
|
662
|
+
with open(jsonl_file, 'r', encoding='utf-8') as infile:
|
663
|
+
data = [json.loads(line) for line in infile]
|
664
|
+
|
665
|
+
if not data:
|
666
|
+
print("Empty JSONL file.")
|
667
|
+
return
|
668
|
+
|
669
|
+
with open(csv_file, 'w', newline='', encoding='utf-8') as outfile:
|
670
|
+
writer = csv.DictWriter(outfile, fieldnames=data[0].keys())
|
671
|
+
writer.writeheader()
|
672
|
+
writer.writerows(data)
|
673
|
+
|
674
|
+
print(f"Converted {jsonl_file} to {csv_file}")
|
675
|
+
|
676
|
+
def create_from_jsonl(self, jsonl_path, dataset_name, schema_mapping):
|
677
|
+
tmp_csv_path = os.path.join(tempfile.gettempdir(), f"{dataset_name}.csv")
|
678
|
+
try:
|
679
|
+
self._jsonl_to_csv(jsonl_path, tmp_csv_path)
|
680
|
+
self.create_from_csv(tmp_csv_path, dataset_name, schema_mapping)
|
681
|
+
except (IOError, UnicodeError) as e:
|
682
|
+
logger.error(f"Error converting JSONL to CSV: {e}")
|
683
|
+
raise
|
684
|
+
finally:
|
685
|
+
if os.path.exists(tmp_csv_path):
|
686
|
+
try:
|
687
|
+
os.remove(tmp_csv_path)
|
688
|
+
except Exception as e:
|
689
|
+
logger.error(f"Error removing temporary CSV file: {e}")
|
690
|
+
|
691
|
+
def add_rows_from_jsonl(self, jsonl_path, dataset_name):
|
692
|
+
tmp_csv_path = os.path.join(tempfile.gettempdir(), f"{dataset_name}.csv")
|
693
|
+
try:
|
694
|
+
self._jsonl_to_csv(jsonl_path, tmp_csv_path)
|
695
|
+
self.add_rows(tmp_csv_path, dataset_name)
|
696
|
+
except (IOError, UnicodeError) as e:
|
697
|
+
logger.error(f"Error converting JSONL to CSV: {e}")
|
698
|
+
raise
|
699
|
+
finally:
|
700
|
+
if os.path.exists(tmp_csv_path):
|
701
|
+
try:
|
702
|
+
os.remove(tmp_csv_path)
|
703
|
+
except Exception as e:
|
704
|
+
logger.error(f"Error removing temporary CSV file: {e}")
|
@@ -0,0 +1,171 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
from typing import Callable, Optional
|
4
|
+
|
5
|
+
import giskard as scanner
|
6
|
+
import pandas as pd
|
7
|
+
|
8
|
+
logging.getLogger('giskard.core').disabled = True
|
9
|
+
logging.getLogger('giskard.scanner.logger').disabled = True
|
10
|
+
logging.getLogger('giskard.models.automodel').disabled = True
|
11
|
+
logging.getLogger('giskard.datasets.base').disabled = True
|
12
|
+
logging.getLogger('giskard.utils.logging_utils').disabled = True
|
13
|
+
|
14
|
+
|
15
|
+
class RedTeaming:
|
16
|
+
|
17
|
+
def __init__(self,
|
18
|
+
provider: Optional[str] = "openai",
|
19
|
+
model: Optional[str] = None,
|
20
|
+
api_key: Optional[str] = None,
|
21
|
+
api_base: Optional[str] = None,
|
22
|
+
api_version: Optional[str] = None):
|
23
|
+
self.provider = provider.lower()
|
24
|
+
self.model = model
|
25
|
+
if not self.provider:
|
26
|
+
raise ValueError("Model configuration must be provided with a valid provider and model.")
|
27
|
+
if self.provider == "openai":
|
28
|
+
if api_key is not None:
|
29
|
+
os.environ["OPENAI_API_KEY"] = api_key
|
30
|
+
if os.getenv("OPENAI_API_KEY") is None:
|
31
|
+
raise ValueError("API key must be provided for OpenAI.")
|
32
|
+
elif self.provider == "gemini":
|
33
|
+
if api_key is not None:
|
34
|
+
os.environ["GEMINI_API_KEY"] = api_key
|
35
|
+
if os.getenv("GEMINI_API_KEY") is None:
|
36
|
+
raise ValueError("API key must be provided for Gemini.")
|
37
|
+
elif self.provider == "azure":
|
38
|
+
if api_key is not None:
|
39
|
+
os.environ["AZURE_API_KEY"] = api_key
|
40
|
+
if api_base is not None:
|
41
|
+
os.environ["AZURE_API_BASE"] = api_base
|
42
|
+
if api_version is not None:
|
43
|
+
os.environ["AZURE_API_VERSION"] = api_version
|
44
|
+
if os.getenv("AZURE_API_KEY") is None:
|
45
|
+
raise ValueError("API key must be provided for Azure.")
|
46
|
+
if os.getenv("AZURE_API_BASE") is None:
|
47
|
+
raise ValueError("API base must be provided for Azure.")
|
48
|
+
if os.getenv("AZURE_API_VERSION") is None:
|
49
|
+
raise ValueError("API version must be provided for Azure.")
|
50
|
+
else:
|
51
|
+
raise ValueError(f"Provider is not recognized.")
|
52
|
+
|
53
|
+
def run_scan(
|
54
|
+
self,
|
55
|
+
model: Callable,
|
56
|
+
evaluators: Optional[list] = None,
|
57
|
+
save_report: bool = True
|
58
|
+
) -> pd.DataFrame:
|
59
|
+
"""
|
60
|
+
Runs red teaming on the provided model and returns a DataFrame of the results.
|
61
|
+
|
62
|
+
:param model: The model function provided by the user (can be sync or async).
|
63
|
+
:param evaluators: Optional list of scan metrics to run.
|
64
|
+
:param save_report: Boolean flag indicating whether to save the scan report as a CSV file.
|
65
|
+
:return: A DataFrame containing the scan report.
|
66
|
+
"""
|
67
|
+
import asyncio
|
68
|
+
import inspect
|
69
|
+
|
70
|
+
self.set_scanning_model(self.provider, self.model)
|
71
|
+
|
72
|
+
supported_evaluators = self.get_supported_evaluators()
|
73
|
+
if evaluators:
|
74
|
+
if isinstance(evaluators, str):
|
75
|
+
evaluators = [evaluators]
|
76
|
+
invalid_evaluators = [evaluator for evaluator in evaluators if evaluator not in supported_evaluators]
|
77
|
+
if invalid_evaluators:
|
78
|
+
raise ValueError(f"Invalid evaluators: {invalid_evaluators}. "
|
79
|
+
f"Allowed evaluators: {supported_evaluators}.")
|
80
|
+
|
81
|
+
# Handle async model functions by wrapping them in a sync function
|
82
|
+
if inspect.iscoroutinefunction(model):
|
83
|
+
def sync_wrapper(*args, **kwargs):
|
84
|
+
try:
|
85
|
+
# Try to get the current event loop
|
86
|
+
loop = asyncio.get_event_loop()
|
87
|
+
except RuntimeError:
|
88
|
+
# If no event loop exists (e.g., in Jupyter), create a new one
|
89
|
+
loop = asyncio.new_event_loop()
|
90
|
+
asyncio.set_event_loop(loop)
|
91
|
+
|
92
|
+
try:
|
93
|
+
# Handle both IPython and regular Python environments
|
94
|
+
import nest_asyncio
|
95
|
+
nest_asyncio.apply()
|
96
|
+
except ImportError:
|
97
|
+
pass # nest_asyncio not available, continue without it
|
98
|
+
|
99
|
+
return loop.run_until_complete(model(*args, **kwargs))
|
100
|
+
wrapped_model = sync_wrapper
|
101
|
+
else:
|
102
|
+
wrapped_model = model
|
103
|
+
|
104
|
+
model_instance = scanner.Model(
|
105
|
+
model=wrapped_model,
|
106
|
+
model_type="text_generation",
|
107
|
+
name="RagaAI's Scan",
|
108
|
+
description="RagaAI's RedTeaming Scan",
|
109
|
+
feature_names=["question"],
|
110
|
+
)
|
111
|
+
|
112
|
+
try:
|
113
|
+
report = scanner.scan(model_instance, only=evaluators, raise_exceptions=True) if evaluators \
|
114
|
+
else scanner.scan(model_instance, raise_exceptions=True)
|
115
|
+
except Exception as e:
|
116
|
+
raise RuntimeError(f"Error occurred during model scan: {str(e)}")
|
117
|
+
|
118
|
+
report_df = report.to_dataframe()
|
119
|
+
|
120
|
+
if save_report:
|
121
|
+
report_df.to_csv("raga-ai_red-teaming_scan.csv", index=False)
|
122
|
+
|
123
|
+
return report_df
|
124
|
+
|
125
|
+
def get_supported_evaluators(self):
|
126
|
+
"""Contains tags corresponding to the 'llm' and 'robustness' directories in the giskard > scanner library"""
|
127
|
+
return {'control_chars_injection',
|
128
|
+
'discrimination',
|
129
|
+
'ethical_bias',
|
130
|
+
'ethics',
|
131
|
+
'faithfulness',
|
132
|
+
'generative',
|
133
|
+
'hallucination',
|
134
|
+
'harmfulness',
|
135
|
+
'implausible_output',
|
136
|
+
'information_disclosure',
|
137
|
+
'jailbreak',
|
138
|
+
'llm',
|
139
|
+
'llm_harmful_content',
|
140
|
+
'llm_stereotypes_detector',
|
141
|
+
'misinformation',
|
142
|
+
'output_formatting',
|
143
|
+
'prompt_injection',
|
144
|
+
'robustness',
|
145
|
+
'stereotypes',
|
146
|
+
'sycophancy',
|
147
|
+
'text_generation',
|
148
|
+
'text_perturbation'}
|
149
|
+
|
150
|
+
def set_scanning_model(self, provider, model=None):
|
151
|
+
"""
|
152
|
+
Sets the LLM model for Giskard based on the provider.
|
153
|
+
|
154
|
+
:param provider: The LLM provider (e.g., "openai", "gemini", "azure").
|
155
|
+
:param model: The specific model name to use (optional).
|
156
|
+
:raises ValueError: If the provider is "azure" and no model is provided.
|
157
|
+
"""
|
158
|
+
default_models = {
|
159
|
+
"openai": "gpt-4o",
|
160
|
+
"gemini": "gemini-1.5-pro"
|
161
|
+
}
|
162
|
+
|
163
|
+
if provider == "azure" and model is None:
|
164
|
+
raise ValueError("Model must be provided for Azure.")
|
165
|
+
|
166
|
+
selected_model = model if model is not None else default_models.get(provider)
|
167
|
+
|
168
|
+
if selected_model is None:
|
169
|
+
raise ValueError(f"Unsupported provider: {provider}")
|
170
|
+
|
171
|
+
scanner.llm.set_llm_model(selected_model)
|
@@ -8,7 +8,9 @@ import markdown
|
|
8
8
|
import pandas as pd
|
9
9
|
import json
|
10
10
|
from litellm import completion
|
11
|
+
import litellm
|
11
12
|
from tqdm import tqdm
|
13
|
+
import tiktoken
|
12
14
|
# import internal_api_completion
|
13
15
|
# import proxy_call
|
14
16
|
from .internal_api_completion import api_completion as internal_api_completion
|
@@ -48,13 +50,18 @@ class SyntheticDataGeneration:
|
|
48
50
|
Raises:
|
49
51
|
ValueError: If an invalid provider is specified or API key is missing.
|
50
52
|
"""
|
53
|
+
text_validity = self.validate_input(text)
|
54
|
+
if text_validity:
|
55
|
+
raise ValueError(text_validity)
|
56
|
+
|
51
57
|
BATCH_SIZE = 5 # Optimal batch size for maintaining response quality
|
52
58
|
provider = model_config.get("provider")
|
53
59
|
model = model_config.get("model")
|
54
60
|
api_base = model_config.get("api_base")
|
61
|
+
api_version = model_config.get("api_version")
|
55
62
|
|
56
63
|
# Initialize the appropriate client based on provider
|
57
|
-
self._initialize_client(provider, api_key, api_base, internal_llm_proxy=kwargs.get("internal_llm_proxy", None))
|
64
|
+
self._initialize_client(provider, api_key, api_base, api_version, internal_llm_proxy=kwargs.get("internal_llm_proxy", None))
|
58
65
|
|
59
66
|
# Initialize progress bar
|
60
67
|
pbar = tqdm(total=n, desc="Generating QA pairs")
|
@@ -88,7 +95,7 @@ class SyntheticDataGeneration:
|
|
88
95
|
pbar.update(len(batch_df))
|
89
96
|
|
90
97
|
except Exception as e:
|
91
|
-
print(f"Batch generation failed
|
98
|
+
print(f"Batch generation failed:{str(e)}")
|
92
99
|
|
93
100
|
if any(error in str(e) for error in FAILURE_CASES):
|
94
101
|
raise Exception(f"{e}")
|
@@ -139,7 +146,7 @@ class SyntheticDataGeneration:
|
|
139
146
|
|
140
147
|
return final_df
|
141
148
|
|
142
|
-
def _initialize_client(self, provider, api_key, api_base=None, internal_llm_proxy=None):
|
149
|
+
def _initialize_client(self, provider, api_key, api_base=None, api_version=None, internal_llm_proxy=None):
|
143
150
|
"""Initialize the appropriate client based on provider."""
|
144
151
|
if not provider:
|
145
152
|
raise ValueError("Model configuration must be provided with a valid provider and model.")
|
@@ -158,7 +165,17 @@ class SyntheticDataGeneration:
|
|
158
165
|
if api_key is None and os.getenv("OPENAI_API_KEY") is None and internal_llm_proxy is None:
|
159
166
|
raise ValueError("API key must be provided for OpenAI.")
|
160
167
|
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
161
|
-
|
168
|
+
|
169
|
+
elif provider == "azure":
|
170
|
+
if api_key is None and os.getenv("AZURE_API_KEY") is None and internal_llm_proxy is None:
|
171
|
+
raise ValueError("API key must be provided for Azure.")
|
172
|
+
litellm.api_key = api_key or os.getenv("AZURE_API_KEY")
|
173
|
+
if api_base is None and os.getenv("AZURE_API_BASE") is None and internal_llm_proxy is None:
|
174
|
+
raise ValueError("API Base must be provided for Azure.")
|
175
|
+
litellm.api_base = api_base or os.getenv("AZURE_API_BASE")
|
176
|
+
if api_version is None and os.getenv("AZURE_API_VERSION") is None and internal_llm_proxy is None:
|
177
|
+
raise ValueError("API version must be provided for Azure.")
|
178
|
+
litellm.api_version = api_version or os.getenv("AZURE_API_VERSION")
|
162
179
|
else:
|
163
180
|
raise ValueError(f"Provider is not recognized.")
|
164
181
|
|
@@ -189,7 +206,15 @@ class SyntheticDataGeneration:
|
|
189
206
|
kwargs=kwargs
|
190
207
|
)
|
191
208
|
|
209
|
+
def validate_input(self,text):
|
192
210
|
|
211
|
+
if not text.strip():
|
212
|
+
return 'Empty Text provided for qna generation. Please provide valid text'
|
213
|
+
encoding = tiktoken.encoding_for_model("gpt-4")
|
214
|
+
tokens = encoding.encode(text)
|
215
|
+
if len(tokens)<5:
|
216
|
+
return 'Very Small Text provided for qna generation. Please provide longer text'
|
217
|
+
return False
|
193
218
|
|
194
219
|
|
195
220
|
def _get_system_message(self, question_type, n):
|
@@ -274,10 +299,14 @@ class SyntheticDataGeneration:
|
|
274
299
|
# Add optional parameters if they exist in model_config
|
275
300
|
if "api_base" in model_config:
|
276
301
|
completion_params["api_base"] = model_config["api_base"]
|
302
|
+
if "api_version" in model_config:
|
303
|
+
completion_params["api_version"] = model_config["api_version"]
|
277
304
|
if "max_tokens" in model_config:
|
278
305
|
completion_params["max_tokens"] = model_config["max_tokens"]
|
279
306
|
if "temperature" in model_config:
|
280
307
|
completion_params["temperature"] = model_config["temperature"]
|
308
|
+
if 'provider' in model_config:
|
309
|
+
completion_params['model'] = f'{model_config["provider"]}/{model_config["model"]}'
|
281
310
|
|
282
311
|
# Make the API call using LiteLLM
|
283
312
|
try:
|
@@ -318,9 +347,13 @@ class SyntheticDataGeneration:
|
|
318
347
|
list_start_index = data.find('[') # Find the index of the first '['
|
319
348
|
substring_data = data[list_start_index:] if list_start_index != -1 else data # Slice from the list start
|
320
349
|
data = substring_data
|
321
|
-
|
350
|
+
elif provider == "azure":
|
351
|
+
data = response.choices[0].message.content.replace('\n', '')
|
352
|
+
list_start_index = data.find('[') # Find the index of the first '['
|
353
|
+
substring_data = data[list_start_index:] if list_start_index != -1 else data # Slice from the list start
|
354
|
+
data = substring_data
|
322
355
|
else:
|
323
|
-
raise ValueError("Invalid provider. Choose 'groq', 'gemini', or 'openai'.")
|
356
|
+
raise ValueError("Invalid provider. Choose 'groq', 'gemini', 'azure' or 'openai'.")
|
324
357
|
try:
|
325
358
|
json_data = json.loads(data)
|
326
359
|
return pd.DataFrame(json_data)
|
@@ -442,7 +475,7 @@ class SyntheticDataGeneration:
|
|
442
475
|
Returns:
|
443
476
|
list: A list of supported AI providers.
|
444
477
|
"""
|
445
|
-
return ['gemini', 'openai']
|
478
|
+
return ['gemini', 'openai','azure']
|
446
479
|
|
447
480
|
# Usage:
|
448
481
|
# from synthetic_data_generation import SyntheticDataGeneration
|
@@ -48,15 +48,15 @@ class AgentTracerMixin:
|
|
48
48
|
self.auto_instrument_network = False
|
49
49
|
|
50
50
|
def trace_agent(
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
51
|
+
self,
|
52
|
+
name: str,
|
53
|
+
agent_type: str = None,
|
54
|
+
version: str = None,
|
55
|
+
capabilities: List[str] = None,
|
56
|
+
tags: List[str] = [],
|
57
|
+
metadata: Dict[str, Any] = {},
|
58
|
+
metrics: List[Dict[str, Any]] = [],
|
59
|
+
feedback: Optional[Any] = None,
|
60
60
|
):
|
61
61
|
if name not in self.span_attributes_dict:
|
62
62
|
self.span_attributes_dict[name] = SpanAttributes(name)
|
@@ -101,7 +101,10 @@ class AgentTracerMixin:
|
|
101
101
|
original_init = target.__init__
|
102
102
|
|
103
103
|
def wrapped_init(self, *args, **kwargs):
|
104
|
-
|
104
|
+
gt = kwargs.get("gt") if kwargs else None
|
105
|
+
if gt is not None:
|
106
|
+
span = self.span(name)
|
107
|
+
span.add_gt(gt)
|
105
108
|
# Set agent context before initializing
|
106
109
|
component_id = str(uuid.uuid4())
|
107
110
|
hash_id = top_level_hash_id
|
@@ -159,7 +162,10 @@ class AgentTracerMixin:
|
|
159
162
|
@self.file_tracker.trace_decorator
|
160
163
|
@functools.wraps(method)
|
161
164
|
def wrapped_method(self, *args, **kwargs):
|
162
|
-
|
165
|
+
gt = kwargs.get("gt") if kwargs else None
|
166
|
+
if gt is not None:
|
167
|
+
span = tracer.span(name)
|
168
|
+
span.add_gt(gt)
|
163
169
|
# Set this agent as current during method execution
|
164
170
|
token = tracer.current_agent_id.set(
|
165
171
|
self._agent_component_id
|
@@ -193,8 +199,8 @@ class AgentTracerMixin:
|
|
193
199
|
children = tracer.agent_children.get()
|
194
200
|
if children:
|
195
201
|
if (
|
196
|
-
|
197
|
-
|
202
|
+
"children"
|
203
|
+
not in component["data"]
|
198
204
|
):
|
199
205
|
component["data"][
|
200
206
|
"children"
|
@@ -247,6 +253,7 @@ class AgentTracerMixin:
|
|
247
253
|
agent_type,
|
248
254
|
version,
|
249
255
|
capabilities,
|
256
|
+
top_level_hash_id,
|
250
257
|
*args,
|
251
258
|
**kwargs,
|
252
259
|
)
|
@@ -256,10 +263,9 @@ class AgentTracerMixin:
|
|
256
263
|
return decorator
|
257
264
|
|
258
265
|
def _trace_sync_agent_execution(
|
259
|
-
|
266
|
+
self, func, name, agent_type, version, capabilities, top_level_hash_id, *args, **kwargs
|
260
267
|
):
|
261
|
-
|
262
|
-
hash_id = str(uuid.uuid4())
|
268
|
+
hash_id = top_level_hash_id
|
263
269
|
|
264
270
|
"""Synchronous version of agent tracing"""
|
265
271
|
if not self.is_active:
|
@@ -276,6 +282,9 @@ class AgentTracerMixin:
|
|
276
282
|
|
277
283
|
# Extract ground truth if present
|
278
284
|
ground_truth = kwargs.pop("gt", None) if kwargs else None
|
285
|
+
if ground_truth is not None:
|
286
|
+
span = self.span(name)
|
287
|
+
span.add_gt(ground_truth)
|
279
288
|
|
280
289
|
# Get parent agent ID if exists
|
281
290
|
parent_agent_id = self.current_agent_id.get()
|
@@ -293,7 +302,7 @@ class AgentTracerMixin:
|
|
293
302
|
|
294
303
|
try:
|
295
304
|
# Execute the agent
|
296
|
-
result = func(*args, **kwargs)
|
305
|
+
result = self.file_tracker.trace_wrapper(func)(*args, **kwargs)
|
297
306
|
|
298
307
|
# Calculate resource usage
|
299
308
|
end_memory = psutil.Process().memory_info().rss
|
@@ -320,9 +329,6 @@ class AgentTracerMixin:
|
|
320
329
|
children=children,
|
321
330
|
parent_id=parent_agent_id,
|
322
331
|
)
|
323
|
-
# Add ground truth to component data if present
|
324
|
-
if ground_truth is not None:
|
325
|
-
agent_component["data"]["gt"] = ground_truth
|
326
332
|
|
327
333
|
# Add this component as a child to parent's children list
|
328
334
|
parent_children.append(agent_component)
|
@@ -384,7 +390,7 @@ class AgentTracerMixin:
|
|
384
390
|
self.agent_children.reset(children_token)
|
385
391
|
|
386
392
|
async def _trace_agent_execution(
|
387
|
-
|
393
|
+
self, func, name, agent_type, version, capabilities, hash_id, *args, **kwargs
|
388
394
|
):
|
389
395
|
"""Asynchronous version of agent tracing"""
|
390
396
|
if not self.is_active:
|
@@ -399,6 +405,9 @@ class AgentTracerMixin:
|
|
399
405
|
|
400
406
|
# Extract ground truth if present
|
401
407
|
ground_truth = kwargs.pop("gt", None) if kwargs else None
|
408
|
+
if ground_truth is not None:
|
409
|
+
span = self.span(name)
|
410
|
+
span.add_gt(ground_truth)
|
402
411
|
|
403
412
|
# Get parent agent ID if exists
|
404
413
|
parent_agent_id = self.current_agent_id.get()
|
@@ -414,7 +423,7 @@ class AgentTracerMixin:
|
|
414
423
|
|
415
424
|
try:
|
416
425
|
# Execute the agent
|
417
|
-
result = await func(*args, **kwargs)
|
426
|
+
result = await self.file_tracker.trace_wrapper(func)(*args, **kwargs)
|
418
427
|
|
419
428
|
# Calculate resource usage
|
420
429
|
end_memory = psutil.Process().memory_info().rss
|
@@ -441,10 +450,6 @@ class AgentTracerMixin:
|
|
441
450
|
parent_id=parent_agent_id,
|
442
451
|
)
|
443
452
|
|
444
|
-
# Add ground truth to component data if present
|
445
|
-
if ground_truth is not None:
|
446
|
-
agent_component["data"]["gt"] = ground_truth
|
447
|
-
|
448
453
|
# Add this component as a child to parent's children list
|
449
454
|
parent_children.append(agent_component)
|
450
455
|
self.agent_children.set(parent_children)
|
@@ -517,7 +522,7 @@ class AgentTracerMixin:
|
|
517
522
|
for interaction in self.component_user_interaction.get(kwargs["component_id"], []):
|
518
523
|
if interaction["interaction_type"] in ["input", "output"]:
|
519
524
|
input_output_interactions.append(interaction)
|
520
|
-
interactions.extend(input_output_interactions)
|
525
|
+
interactions.extend(input_output_interactions)
|
521
526
|
if self.auto_instrument_file_io:
|
522
527
|
file_io_interactions = []
|
523
528
|
for interaction in self.component_user_interaction.get(kwargs["component_id"], []):
|
@@ -546,9 +551,10 @@ class AgentTracerMixin:
|
|
546
551
|
counter = sum(1 for x in self.visited_metrics if x.startswith(base_metric_name))
|
547
552
|
metric_name = f'{base_metric_name}_{counter}' if counter > 0 else base_metric_name
|
548
553
|
self.visited_metrics.append(metric_name)
|
549
|
-
metric["name"] = metric_name
|
554
|
+
metric["name"] = metric_name
|
550
555
|
metrics.append(metric)
|
551
556
|
|
557
|
+
# TODO agent_trace execute metric
|
552
558
|
component = {
|
553
559
|
"id": kwargs["component_id"],
|
554
560
|
"hash_id": kwargs["hash_id"],
|
@@ -576,8 +582,13 @@ class AgentTracerMixin:
|
|
576
582
|
"interactions": interactions,
|
577
583
|
}
|
578
584
|
|
579
|
-
if self.
|
580
|
-
|
585
|
+
if name in self.span_attributes_dict:
|
586
|
+
span_gt = self.span_attributes_dict[name].gt
|
587
|
+
if span_gt is not None:
|
588
|
+
component["data"]["gt"] = span_gt
|
589
|
+
span_context = self.span_attributes_dict[name].context
|
590
|
+
if span_context:
|
591
|
+
component["data"]["context"] = span_context
|
581
592
|
|
582
593
|
# Reset the SpanAttributes context variable
|
583
594
|
self.span_attributes_dict[kwargs["name"]] = SpanAttributes(kwargs["name"])
|
@@ -599,22 +610,22 @@ class AgentTracerMixin:
|
|
599
610
|
self.component_network_calls.set(component_network_calls)
|
600
611
|
|
601
612
|
def _sanitize_input(self, args: tuple, kwargs: dict) -> dict:
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
+
"""Sanitize and format input data, including handling of nested lists and dictionaries."""
|
614
|
+
|
615
|
+
def sanitize_value(value):
|
616
|
+
if isinstance(value, (int, float, bool, str)):
|
617
|
+
return value
|
618
|
+
elif isinstance(value, list):
|
619
|
+
return [sanitize_value(item) for item in value]
|
620
|
+
elif isinstance(value, dict):
|
621
|
+
return {key: sanitize_value(val) for key, val in value.items()}
|
622
|
+
else:
|
623
|
+
return str(value) # Convert non-standard types to string
|
613
624
|
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
625
|
+
return {
|
626
|
+
"args": [sanitize_value(arg) for arg in args],
|
627
|
+
"kwargs": {key: sanitize_value(val) for key, val in kwargs.items()},
|
628
|
+
}
|
618
629
|
|
619
630
|
def _sanitize_output(self, output: Any) -> Any:
|
620
631
|
"""Sanitize and format output data"""
|
@@ -630,6 +641,6 @@ class AgentTracerMixin:
|
|
630
641
|
|
631
642
|
def instrument_network_calls(self):
|
632
643
|
self.auto_instrument_network = True
|
633
|
-
|
644
|
+
|
634
645
|
def instrument_file_io_calls(self):
|
635
646
|
self.auto_instrument_file_io = True
|